import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from torchcrf import CRF
import pickle
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
from seqeval.scheme import IOB2
from torch.optim.lr_scheduler import OneCycleLR
# 定义运算设备
# Select the fastest available compute device, preferring CUDA, then Apple
# MPS, and finally falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
def load_data(path):
    """Deserialize and return the pickled object stored at *path*."""
    # NOTE(review): pickle.load must never be fed untrusted files; these are
    # assumed to be project-local dataset/vocab files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)
class DataSet(Dataset):
    """Character-level sequence-labeling dataset.

    Maps each character of a text to a vocabulary id and each tag to a
    class id, producing parallel id lists per example.
    """

    def __init__(self, path, vocab, classes):
        """
        path: pickle file containing {'texts': [...], 'labels': [...]}
        vocab: dict mapping character -> integer id (must contain '<unk>')
        classes: dict mapping tag string -> integer id
        """
        data = load_data(path)
        self.vocab = vocab
        self.classes = classes
        self.X = data['texts']
        self.y = data['labels']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Fix: loop variables previously shadowed the builtin `chr`.
        # Out-of-vocabulary characters fall back to the '<unk>' id.
        x_ids = [self.vocab.get(ch, self.vocab['<unk>']) for ch in self.X[idx]]
        # Tags are expected to always exist in `classes`; a KeyError here
        # signals a label set mismatch.
        y_ids = [self.classes[tag] for tag in self.y[idx]]
        return x_ids, y_ids
def collate_fn(max_length=128):
    """Build a DataLoader collate function via a closure.

    max_length: hard cap on sequence length; longer sequences are truncated.

    The returned callable pads each batch to the length of its longest
    (possibly truncated) sequence, fills label padding with -100, and builds
    the boolean mask required by torchcrf.
    """
    def collate(batch):
        xs, ys = zip(*batch)
        # Bug fix: the original computed max_len purely from the batch and
        # never applied max_length, so truncation silently never happened.
        max_len = min(max(len(seq) for seq in xs), max_length)
        x_padded = torch.zeros((len(xs), max_len), dtype=torch.long)
        # -100 marks label padding (conventionally ignored by loss functions).
        y_padded = torch.full((len(ys), max_len), -100, dtype=torch.long)
        mask = torch.zeros((len(xs), max_len), dtype=torch.bool)
        for i, (x_i, y_i) in enumerate(zip(xs, ys)):
            x_i = x_i[:max_len]
            y_i = y_i[:max_len]
            x_padded[i, :len(x_i)] = torch.tensor(x_i)
            y_padded[i, :len(y_i)] = torch.tensor(y_i)
            mask[i, :len(x_i)] = True
        return (x_padded, y_padded, mask)
    return collate
class LSTMCRF(nn.Module):
    """LSTM encoder + linear emission layer + CRF output layer for tagging."""

    def __init__(self, vocab_size, embedding_size, hidden_size, output_size, layer_nums):
        """
        vocab_size: size of the input token vocabulary
        embedding_size: embedding dimension
        hidden_size: LSTM hidden dimension
        output_size: number of tag classes
        layer_nums: number of stacked LSTM layers
        """
        super(LSTMCRF, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.layer_nums = layer_nums
        self.embedding = nn.Embedding(vocab_size, embedding_size, dtype=torch.float)
        # Bug fix: layer_nums was stored but never passed to nn.LSTM, so the
        # model always ran with a single layer regardless of the argument.
        self.lstm = nn.LSTM(embedding_size, hidden_size,
                            num_layers=layer_nums, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        # The CRF layer serves as both the loss function and the decoder.
        self.CRF = CRF(output_size, batch_first=True)

    def forward(self, X):
        """Return per-token emission scores of shape (batch, seq, output_size)."""
        x = self.embedding(X)
        x, _ = self.lstm(x)
        emissions = self.fc(x)
        return emissions

    def loss(self, emissions, tags, mask):
        # CRF returns the log-likelihood; negate it for a minimizable loss.
        return -self.CRF(emissions, tags, mask)

    def decode(self, emissions, mask):
        # Viterbi decoding: returns one variable-length tag-id list per sequence.
        return self.CRF.decode(emissions, mask)
def train_model(model, train_loader, dev_loader, optimizer, num_epochs, id2class, scheduler):
    """Train the model, evaluating on dev each epoch and checkpointing the best F1.

    model: LSTMCRF-style module exposing .loss(emissions, tags, mask)
    train_loader / dev_loader: DataLoaders yielding (X, y, mask) batches
    optimizer: optimizer over model.parameters()
    num_epochs: number of training epochs
    id2class: {class_id: tag_string} mapping for evaluation
    scheduler: per-batch LR scheduler (OneCycleLR configured with
               steps_per_epoch=len(train_loader))
    """
    best_score = 0  # best dev F1 seen so far
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            X, y, mask = batch
            X, y, mask = X.to(device), y.to(device), mask.to(device)
            optimizer.zero_grad()
            emissions = model(X)
            loss = model.loss(emissions, y, mask=mask)
            loss.backward()
            optimizer.step()
            # Bug fix: OneCycleLR is built with steps_per_epoch=len(train_loader),
            # so it must be stepped once per BATCH. The original stepped once per
            # epoch, leaving the schedule stuck near the start of its warm-up.
            scheduler.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}')
        f1 = evaluate(model, dev_loader, id2class)
        print(f'Epoch {epoch+1}/{num_epochs} - F1: {f1:.4f}')
        print(f"学习率 --> {scheduler.get_last_lr()}")
        # Keep only the checkpoint with the best dev F1.
        if f1 > best_score:
            best_score = f1
            torch.save(model.state_dict(), "best_model.pth")
def evaluate(model, dev_loader, id2class):
    """Decode the whole dev set and return the macro F1 (seqeval, IOB2 scheme).

    Also prints the full seqeval classification report.
    """
    model.eval()
    all_preds = []
    all_golds = []
    with torch.no_grad():
        for X, y, mask in tqdm(dev_loader, desc=f"Evaluate "):
            X, y, mask = X.to(device), y.to(device), mask.to(device)
            emissions = model(X)
            # CRF decoding yields one variable-length tag-id list per sequence.
            decoded = model.decode(emissions, mask)
            for tag_ids, gold_row, row_mask in zip(decoded, y, mask):
                all_preds.append([id2class[t] for t in tag_ids])
                # Keep only unpadded gold positions, mapped back to tag strings.
                all_golds.append([id2class[g.item()] for g in gold_row[row_mask]])
    # seqeval computes entity-level metrics under the IOB2 scheme.
    print(classification_report(all_golds, all_preds, scheme=IOB2))
    return f1_score(all_golds, all_preds, average='macro', scheme=IOB2)
def main():
    """End-to-end pipeline: load data, train the LSTM-CRF tagger, test the best checkpoint."""
    # Dataset/vocab pickle paths (project-local files).
    train_path = 'train'
    dev_path = 'dev'
    test_path = 'test'
    vocab_path = 'vocab'
    classes_path = 'classes'
    # Load vocab (char -> id) and classes (tag -> id).
    vocab = load_data(vocab_path)
    classes = load_data(classes_path)
    vocab_size = len(vocab)
    classes_size = len(classes)
    # Inverse mapping (id -> tag) used during evaluation/decoding.
    id2class = {v: k for k, v in classes.items()}
    # Build datasets and loaders.
    train_dataset = DataSet(train_path, vocab, classes)
    dev_dataset = DataSet(dev_path, vocab, classes)
    test_dataset = DataSet(test_path, vocab, classes)
    torch.manual_seed(42)  # reproducible shuffling/initialization
    max_length = 128
    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn(max_length))
    # NOTE(review): shuffling dev/test is unnecessary for metrics; kept for parity.
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn(max_length))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_fn(max_length), shuffle=False)
    # Hyperparameters.
    embedding_size = 256
    hidden_size = 256
    output_size = classes_size
    layer_nums = 1
    num_epochs = 50
    lr = 1e-4  # base LR; OneCycleLR overrides it up to max_lr during training
    # Model, optimizer, LR scheduler.
    model = LSTMCRF(vocab_size, embedding_size, hidden_size, output_size, layer_nums)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=num_epochs)
    # Train (checkpoints best dev-F1 weights to best_model.pth).
    train_model(model, train_loader, dev_loader, optimizer, num_epochs, id2class, scheduler)
    # Bug fix: map_location ensures a CUDA-trained checkpoint also loads on
    # CPU/MPS hosts instead of raising a deserialization error.
    model.load_state_dict(torch.load("best_model.pth", map_location=device))
    # Final evaluation on the held-out test set.
    test_f1 = evaluate(model, test_loader, id2class)
    print(f'Test F1 Score: {test_f1}')
if __name__ == '__main__':
    main()
没有合适的资源?快使用搜索试试~ 我知道了~
使用LSTM+CRF实现NER
共12个文件
train:2个
test:2个
dev:2个
0 下载量 65 浏览量
2024-09-27
08:23:42
上传
评论
收藏 6.95MB ZIP 举报
温馨提示
使用LSTM+CRF实现NER
资源推荐
资源详情
资源评论
收起资源包目录
归档.zip (12个子文件)
Test.ipynb 6KB
classes 92B
train 816KB
weiboNER_2nd_conll.dev 103KB
vocab 27KB
weiboNER_2nd_conll.test 106KB
weiboNER_2nd_conll.train 523KB
save_data.py 898B
best_model.pth 5.08MB
test 164KB
dev 160KB
lstm_crf.py 7KB
共 12 条
- 1
资源评论
多吃轻食
- 粉丝: 890
- 资源: 3
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 青藏高原冻土空间分布-2023年最新绘制
- order system(1).c
- 基于微博数据的舆情分析项目(包括微博爬虫、LDA主题分析和情感分析)高分项目
- 测试电路板用的双针床设备(含工程图sw17可编辑+cad)全套技术开发资料100%好用.zip
- 基于Python控制台的网络入侵检测
- 基于微博数据的舆情分析项目-包括数据分析、LDA主题分析和情感分析(高分项目源码)
- 制作生成自己专属的安卓app应用 制作apk
- 基于python开发的贪食蛇(源码)
- frmcurvechart.ui
- NSFetchedResultsControllerError如何解决.md
- 基于java银行客户信息管理系统论文.doc
- EmptyStackException(解决方案).md
- RuntimeError.md
- wqwerwerwere
- 基于java+ssm+mysql的4S店预约保养系统任务书.docx
- 基于java在线考试系统2毕业论文.doc
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功