LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略
目录
源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估
# ignore_warnings函数用于忽略 Python 警告信息,通常在训练脚本中用于减少控制台输出的干扰。
# 1.1、解析命令行参数,ScriptArguments数据类配置和控制DPO训练过程:
# 1.2、根据命令行参数 script_args.torch_dtype 的值来确定模型的数据类型
# 1.4、if判断是否忽略模型中的偏置(bias)缓冲区:针对 PyTorch 分布式训练的修复,用于处理与偏置相关的问题。
# 1.5、加载参考模型:常用于DPO训练中,可能会与主模型一起用于比较或其他目的
# 2.1、指定数据格式都为json格式,每个样本是一个dict
# 2.2、通过load_dataset函数从jason格式的train_file和validation_file中加载训练和验证数据集
# 3.1、计算一些训练相关的超参数和配置,如训练集和验证集的样本数量、GPU数量、batch_size、
# 3.2、初始化TrainingArguments训练参数:包括学习率、批量大小、模型保存策略、日志设置等等
# (1)、使用 CustomDPOTrainer类对训练过程进行封装。包括了模型、模型参考、训练参数、训练数据集、评估数据集(如果提供)、分词器和数据收集器
# (2)、利用 compute_metrics函数分配给 dpo_trainer 的 compute_metrics 属性,以计算评估指标
# (3)、获取当前进程的全局排名,在分布式训练中很有用,用于确定每个进程的唯一标识。
# 3.5、启动训练过程:调用trainer的train()接口训练模型