作者:一个处女座的程序猿

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

目录

源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估

# 0、获取环境变量(分布式训练中的进程数量)

# ignore_warnings函数用于忽略 Python 警告信息,通常在训练脚本中用于减少控制台输出的干扰。

# 1、解析命令行参数与初始化

# 1.1、解析命令行参数,ScriptArguments数据类配置和控制DPO训练过程:

# 1.2、根据命令行参数 script_args.torch_dtype 的值来确定模型的数据类型

# 1.3、加载预训练模型

# 1.4、if判断是否忽略模型中的偏置(bias)缓冲区:针对 PyTorch 分布式训练的修复,用于处理与偏置相关的问题。

# 1.5、加载参考模型:常用于DPO训练中,可能会与主模型一起用于比较或其他目的

# 1.6、加载tokenizer并padding操作

# 2、加载数据集:训练数据、验证数据

# 2.1、指定数据格式都为json格式,每个样本是一个dict

# 2.2、通过load_dataset函数从jason格式的train_file和validation_file中加载训练和验证数据集

# 3、模型训练

# 3.1、计算一些训练相关的超参数和配置,如训练集和验证集的样本数量、GPU数量、batch_size、

# 3.2、初始化TrainingArguments训练参数:包括学习率、批量大小、模型保存策略、日志设置等等

# 3.3、自定义的DPOTrainer初始化(trl库)

# (1)、使用 CustomDPOTrainer类对训练过程进行封装。包括了模型、模型参考、训练参数、训练数据集、评估数据集(如果提供)、分词器和数据收集器

# (2)、利用 compute_metrics函数分配给 dpo_trainer 的 compute_metrics 属性,以计算评估指标

# (3)、获取当前进程的全局排名,在分布式训练中很有用,用于确定每个进程的唯一标识。

# 3.4、打印训练配置

# 3.5、启动训练过程:调用trainer的train()接口训练模型

实战代码



 

源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(D

lock