摘要
论文链接:https://arxiv.org/abs/1803.05407.pdf
官方代码:https://github.com/timgaripov/swa
论文翻译:【第32篇】SWA:平均权重导致更广泛的最优和更好的泛化_AI浩的博客-CSDN博客
SWA简单来说就是对训练过程中的多个checkpoints进行平均,以提升模型的泛化性能。记训练过
程第 个epoch的checkpoint为 ,一般情况下我们会选择训练过程中最后的一个epoch的模型 或
者在验证集上效果最好的一个模型 作为最终模型。但SWA一般在最后采用较高的固定学习速率或者
周期式学习速率额外训练一段时间,取多个checkpoints的平均值。
pytorch使用举例:
上面的代码展示了SWA的主要代码,实现的步骤:
1、定义SGD优化器。
2、定义SWA。
3、定义SWALR,调整模型的学习率。
4、开始训练,等待训练完成。
5、在每个epoch中更新模型的参数,更新学习率。
6、等待训练完成后,更新BN层的参数。
from torch.optim.swa_utils import AveragedModel, SWALR
# 采用SGD优化器
optimizer = torch.optim.SGD(model.parameters(),lr=1e-4, weight_decay=1e-3,
momentum=0.9)
# 随机权重平均SWA,实现更好的泛化
swa_model = AveragedModel(model).to(device)
# SWA调整学习率
swa_scheduler = SWALR(optimizer, swa_lr=1e-6)
for epoch in range(1, epoch + 1):
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device, non_blocking=True), target.to(device,
non_blocking=True)
# 在反向传播前要手动将梯度清零
optimizer.zero_grad()
output = model(data)
#计算losss
loss = train_criterion(output, targets)
# 反向传播求解梯度
loss.backward()
optimizer.step()
lr = optimizer.state_dict()['param_groups'][0]['lr']
swa_model.update_parameters(model)
swa_scheduler.step()
# 最后更新BN层参数
torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
# 保存结果
torch.save(swa_model.state_dict(), "last.pt")
评论0