import os
import sys
import torch
import hashlib
from itertools import chain
from typing import List, Literal, Optional, Tuple
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import datasets
from datasets import Dataset, concatenate_datasets, load_dataset
from peft import (
    PeftModel,
    TaskType,
    LoraConfig,
    get_peft_model
)
from peft.utils import CONFIG_NAME
from trl import AutoModelForCausalLMWithValueHead
from .config import (
    ModelArguments,
    DataTrainingArguments,
    FinetuningArguments,
    GeneratingArguments
)
from .template import Template
from .other import (
    get_logger,
    load_trainable_params,
    load_valuehead_params,
    print_trainable_params,
    prepare_model_for_training,
    IGNORE_INDEX
)

check_min_version("4.29.1")  # minimum required version of transformers
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")

logger = get_logger(__name__)

def _init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool,
    is_mergeable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Supports full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """
    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
        else:
            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None: # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)

        if is_trainable and latest_checkpoint is None: # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model
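
# Illustrative note (not in the original source): given multiple LoRA checkpoints,
# e.g. `--checkpoint_dir ckpt_a,ckpt_b,ckpt_c --resume_lora_training True`, the
# logic above merges ckpt_a and ckpt_b into the base weights via merge_and_unload()
# and re-attaches ckpt_c as a trainable PeftModel, so training resumes from the last
# checkpoint. With `--resume_lora_training False` (and a mergeable, non-quantized
# model), all three checkpoints are merged and a fresh LoraConfig is created instead.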


def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads the pretrained model and tokenizer.

    Supports both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("No checkpoint found for evaluation, loading the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )

        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    # Load and prepare pretrained models (without valuehead).
    # NOTE: this copy of the file is truncated at this point; the remainder of the
    # function below is a reconstruction inferred from the imports and the return
    # type annotation above, not verbatim original code.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=model_args.compute_dtype,  # assumed; the original may pick fp16/bf16 explicitly
        low_cpu_mem_usage=True,
        **config_kwargs
    )

    # Prepare the model for training (helper from .other), then attach the adapters.
    model = prepare_model_for_training(model) if is_trainable else model
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    if stage == "rm" or stage == "ppo": # reward modeling and PPO need a value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        if model_args.checkpoint_dir is not None:
            load_valuehead_params(model, model_args.checkpoint_dir[-1])  # assumed: restore value-head weights

    if not is_trainable:
        model.requires_grad_(False) # fix all model params for inference

    print_trainable_params(model)

    return model, tokenizer
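

# Minimal usage sketch (not part of the original file). The argument values and the
# direct use of HfArgumentParser here are illustrative assumptions; the project
# normally drives load_pretrained() from its training and demo scripts. Run as a
# module (e.g. `python -m utils.common --model_name_or_path <path>`) so that the
# relative imports above resolve.
if __name__ == "__main__":
    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
    model_args, finetuning_args = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args, finetuning_args, is_trainable=False, stage="sft")
    logger.info("Loaded model and tokenizer from {}.".format(model_args.model_name_or_path))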