import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import get_linear_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter
from utils import load_data, CustomDataset
from model import BertTextCNN
# 模型训练
def train_model(model, train_loader, val_loader, num_epochs, optimizer, device, patience=5):
# 设置模型为训练模式,这是为了启用Dropout和Batch Normalization等在训练时使用的特性
model.train()
# 添加学习率调度器
total_steps = len(train_loader) * num_epochs
# 该调度器用于动态调整学习率,其中num_warmup_steps表示预热步数,num_training_steps表示总训练步数。
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
# 添加早停机制(防过拟合):设置一个patience参数,表示连续多少个epoch验证集acc没有减小就停止训练
best_val_accuracy = 0.93
patience_count = 0
for epoch in range(num_epochs):
total_loss = 0
'''看一下训练集的准确度,可删'''
correct_predictions = 0
total_samples = 0
# 使用tqdm包装train_loader,显示进度条
train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
for step, (input_ids, attention_mask, labels) in enumerate(train_loader_tqdm):
input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
# 将优化器的梯度置零,用于反向传播前准备。
optimizer.zero_grad()
logits = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
# 使用交叉熵损失函数计算当前batch的损失
loss = F.cross_entropy(logits, labels)
total_loss += loss.item()
'''看一下训练集的准确度,可删'''
predicted_labels = torch.argmax(logits, dim=1)
correct_predictions += (predicted_labels == labels).sum().item()
total_samples += len(labels)
# 反向传播,计算模型参数的梯度
loss.backward()
# 根据梯度更新模型参数
optimizer.step()
# 调整学习率
scheduler.step()
# 更新进度条
train_loader_tqdm.set_postfix({"Loss": loss.item()})
average_loss = total_loss / len(train_loader)
train_acc = correct_predictions / total_samples
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
'''看一下训练集的准确度,可删'''
print(f"训练集正确数:{correct_predictions},样本总数:{total_samples},准确率:{train_acc}")
# 在每个epoch结束后进行验证模型性能
val_acc, val_loss = evaluate_model(model, val_loader, device)
# 记录准确率
acc_metrics = {'Train': train_acc, 'Validation': val_acc}
writer.add_scalars("Accuracy", acc_metrics, epoch)
# 记录损失率
loss_metrics = {'Train': average_loss, 'Validation': val_loss}
writer.add_scalars("Loss", loss_metrics, epoch)
if val_acc > best_val_accuracy:
best_val_accuracy = val_acc
patience_count = 0
# 保存最佳模型
torch.save(model.state_dict(), f'./model/best_model_accuracy_{best_val_accuracy:.4f}')
else:
patience_count += 1
if patience_count >= patience:
print(f"由于validation loss没有提升,提前在epoch {epoch + 1} 停止。")
break
return best_val_accuracy
# 模型验证
def evaluate_model(model, val_loader, device):
# 设置模型为评估模式,这是为了禁用Dropout和Batch Normalization等在评估时使用的特性
model.eval()
total_loss = 0
correct_predictions = 0
total_samples = 0
# 上下文管理器,确保在验证阶段不会计算梯度,从而节省内存和计算资源
with torch.no_grad():
for input_ids, attention_mask, labels in val_loader:
input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
logits = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = F.cross_entropy(logits, labels) # 使用交叉熵损失函数
total_loss += loss.item()
predicted_labels = torch.argmax(logits, dim=1)
correct_predictions += (predicted_labels == labels).sum().item()
# 在验证集很大的情况下,可以逐个批次进行验证,而不需要一次性将所有验证样本加载到内存中
total_samples += len(labels)
average_loss = total_loss / len(val_loader)
accuracy = correct_predictions / total_samples
print(f"Validation Loss: {average_loss}, Accuracy: {accuracy}")
return accuracy, average_loss
if __name__ == "__main__":
train_file_path = 'train.csv'
val_file_path = 'val.csv'
# 加载BertTokenizer分词器,用于将文本转换成模型可以处理的输入格式
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 定义一个目录用于保存TensorBoard的日志
tensorboard_dir = "./logs"
# 初始化SummaryWriter对象,用于记录训练过程中的日志和指标
writer = SummaryWriter(log_dir=tensorboard_dir)
# 加载训练集数据
train_texts, train_labels, label2id = load_data(train_file_path)
train_dataset = CustomDataset(train_texts, train_labels, tokenizer, label2id)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# 加载验证集数据
val_texts, val_labels, _ = load_data(val_file_path)
val_dataset = CustomDataset(val_texts, val_labels, tokenizer, label2id)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
# 加载BertForSequenceClassification模型
bert_model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=len(label2id))
# 自定义的TextCNN参数
vocab_size = len(tokenizer) # 替换为实际词汇表大小
embedding_dim = 768 # 与Bert模型的隐藏层维度相同
num_filters = 100 # 卷积核的数量
filter_sizes = [2, 3, 4, 5] # 不同大小的卷积核尺寸
dropout = 0.3 # dropout率
# 将Bert和TextCNN连接成一个整体模型
model = BertTextCNN(bert_model, num_labels=len(label2id), vocab_size=vocab_size,
embedding_dim=embedding_dim, num_filters=num_filters,
filter_sizes=filter_sizes, dropout=dropout)
# 定义优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# 训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
best_val_acc = train_model(model, train_loader, val_loader, num_epochs=20, optimizer=optimizer, device=device)
# 加载最佳模型并在验证集上重新评估性能
model.load_state_dict(torch.load(f'./model/best_model_accuracy_{best_val_acc:.4f}.pth'))
evaluate_model(model, val_loader, device)
# 关闭SummaryWriter对象
writer.close()
# 保存模型的状态字典(权重)
torch.save(model.state_dict(), './model/model_last') # 将模型保存到'./model'
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。 基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完整下载就可以使用。基于Bert+TextCNN模型的文本分类项目源码.zip 代码完
资源推荐
资源详情
资源评论
收起资源包目录
基于Bert+TextCNN模型的文本分类项目.zip (13个子文件)
基于Bert+TextCNN模型的文本分类项目源码
utils.py 2KB
main.py 7KB
val.csv 138KB
model.py 3KB
.idea
vcs.xml 180B
misc.xml 299B
inspectionProfiles
profiles_settings.xml 174B
modules.xml 290B
.gitignore 176B
text_classification.iml 324B
train.csv 553KB
test.py 3KB
test.csv 138KB
共 13 条
- 1
猰貐的新时代
- 粉丝: 1w+
- 资源: 2571
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页