别再只调包了!手把手带你用PyTorch从零实现BiLSTM-CRF命名实体识别(附完整代码)

张开发
2026/4/17 15:19:33 15 分钟阅读

分享文章

别再只调包了!手把手带你用PyTorch从零实现BiLSTM-CRF命名实体识别(附完整代码)
从零构建BiLSTM-CRF命名实体识别模型深入原理与完整PyTorch实现1. 命名实体识别技术全景命名实体识别NER作为自然语言处理的基础任务其核心目标是从非结构化文本中定位并分类特定类型的实体。不同于简单的词典匹配现代NER系统需要解决实体边界模糊、类型多样、上下文依赖等复杂问题。在医疗领域NER面临独特挑战专业术语密集弥漫大B细胞淋巴瘤等专业名词需要精确识别表述变体丰富同一症状可能有气促、呼吸困难等多种表述标签约束严格疾病与症状标签需遵循B-dis/I-dis、B-sym/I-sym的标注规则传统方法对比# 基于规则的方法示例 def rule_based_ner(text): disease_dict [淋巴瘤, 肺癌, 冠心病] for term in disease_dict: if term in text: yield (term, disease) # 基于统计的方法HMM示例 hmm HiddenMarkovModel(transition_prob, emission_prob)2. BiLSTM-CRF架构深度解析2.1 模型整体架构graph TD A[输入序列] -- B[词嵌入层] B -- C[BiLSTM层] C -- D[全连接层] D -- E[CRF层] E -- F[预测标签序列]2.2 关键组件实现词嵌入层配置class BiLSTM_CRF(nn.Module): def __init__(self, vocab_size, embedding_dim, ...): super().__init__() self.embedding nn.Embedding( num_embeddingsvocab_size, embedding_dimembedding_dim, padding_idx0 )BiLSTM层参数参数名典型值作用hidden_size100单层LSTM单元数num_layers2堆叠LSTM层数dropout0.3层间dropout概率bidirectionalTrue启用双向结构CRF转移矩阵设计self.transitions nn.Parameter(torch.randn(tag_size, tag_size)) # 约束非法转移 self.transitions.data[START_TAG, :] -10000 self.transitions.data[:, STOP_TAG] -100003. 完整实现流程3.1 数据预处理管道def build_vocab(texts): char_to_id {PAD: 0, UNK: 1} for text in texts: for char in text: if char not in char_to_id: char_to_id[char] len(char_to_id) return char_to_id def sentence_to_ids(sentence, char_to_id, max_len): ids [char_to_id.get(c, 1) for c in sentence] ids ids[:max_len] [0]*(max_len - len(ids)) return torch.tensor(ids, dtypetorch.long)3.2 模型训练关键代码def train_epoch(model, dataloader, optimizer): model.train() total_loss 0 for inputs, tags in tqdm(dataloader): optimizer.zero_grad() loss model.neg_log_likelihood(inputs, tags) loss.backward() optimizer.step() total_loss loss.item() return total_loss / len(dataloader)3.3 维特比解码实现def viterbi_decode(emissions, transitions): seq_length, num_tags emissions.shape viterbi torch.zeros(seq_length, num_tags) backpointers torch.zeros(seq_length, num_tags, dtypetorch.long) # 初始化第一步 viterbi[0] emissions[0] for t in range(1, seq_length): scores viterbi[t-1].unsqueeze(1) transitions viterbi[t] emissions[t] scores.max(0)[0] backpointers[t] scores.argmax(0) # 回溯最佳路径 best_path [viterbi[-1].argmax().item()] for t in reversed(range(1, seq_length)): best_path.append(backpointers[t, best_path[-1]].item()) return best_path[::-1]4. 医疗NER实战技巧4.1 特殊处理策略领域词典增强medical_terms { B-dis: [淋巴瘤, 肺炎], I-dis: [型, 期] }标签不平衡处理class_weights torch.tensor([1.0, 2.0, 2.0, 3.0, 3.0]) # O, B-dis, I-dis, B-sym, I-sym criterion nn.CrossEntropyLoss(weightclass_weights)4.2 评估指标优化精确率-召回率平衡策略def evaluate(y_true, y_pred): tp ((y_true y_pred) (y_true ! 0)).sum().item() pred_pos (y_pred ! 0).sum().item() true_pos (y_true ! 0).sum().item() precision tp / (pred_pos 1e-10) recall tp / (true_pos 1e-10) f1 2 * precision * recall / (precision recall 1e-10) return precision, recall, f15. 进阶优化方向5.1 模型改进方案层次化表示self.char_embedding nn.Embedding(char_vocab_size, char_embed_dim) self.word_lstm nn.LSTM(word_embed_dim, hidden_size//2, bidirectionalTrue)注意力机制增强self.attention nn.MultiheadAttention(embed_dimhidden_size, num_heads4)5.2 预训练模型集成from transformers import BertModel class BERT_BiLSTM_CRF(nn.Module): def __init__(self, bert_path, tagset_size): super().__init__() self.bert BertModel.from_pretrained(bert_path) self.lstm nn.LSTM(768, hidden_size//2, bidirectionalTrue) self.crf CRF(tagset_size)6. 完整代码结构项目目录结构ner_project/ ├── data/ │ ├── train.txt # 标注数据 │ └── vocab.json # 字符词典 ├── model/ │ ├── bilstm_crf.py # 模型定义 │ └── train.py # 训练脚本 ├── utils/ │ ├── data_loader.py # 数据加载 │ └── metrics.py # 评估指标 └── config.yaml # 超参数配置核心训练循环for epoch in range(epochs): train_loss train_epoch(model, train_loader, optimizer) val_metrics evaluate(model, val_loader) print(fEpoch {epoch}:) print(f Train Loss: {train_loss:.4f}) print(f Val F1: {val_metrics[f1]:.4f}) if val_metrics[f1] best_f1: torch.save(model.state_dict(), best_model.pt) best_f1 val_metrics[f1]7. 实际应用建议领域适应技巧使用领域特定语料继续预训练设计领域相关的标签约束规则部署优化model BiLSTM_CRF.load_from_checkpoint(best_model.pt) model.eval() torchscript_model torch.jit.script(model) # 转换为TorchScript持续学习策略def online_learning(new_data): optimizer torch.optim.SGD(model.parameters(), lr0.01) for x, y in new_data: loss model(x, y) loss.backward() optimizer.step()通过本实现开发者不仅能掌握BiLSTM-CRF的核心原理还能获得可直接应用于实际项目的完整代码框架。建议在医疗文本上测试时重点关注模型对嵌套实体、不完整表述等复杂情况的处理能力。

更多文章