# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
from __future__ import absolute_import, division, print_function
import argparse
import glob
import logging
import os
import random
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from transformers import DataProcessor, InputExample, InputFeatures
try:
from torch.utils.tensorboard import SummaryWriter
except ImportError:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME, BertConfig,
BertForSequenceClassification, BertTokenizer,
RobertaConfig,
RobertaForSequenceClassification,
RobertaTokenizer,
XLMConfig, XLMForSequenceClassification,
XLMTokenizer, XLNetConfig,
XLNetForSequenceClassification,
XLNetTokenizer,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer,
AlbertConfig,
AlbertForSequenceClassification,
AlbertTokenizer,
)
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors
from transformers import glue_convert_examples_to_features as convert_examples_to_features
from sklearn.metrics import f1_score
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
                  for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig, AlbertConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
}
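# Each --model_type key maps to its (config, model, tokenizer) classes, so the
# rest of the script can stay model-agnostic: the classes are looked up once
# and then used to load the checkpoint named by --model_name_or_path.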
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def simple_accuracy(preds, labels):
return (preds == labels).mean()
def acc_and_f1(preds, labels):
acc = simple_accuracy(preds, labels)
    f1 = f1_score(y_true=labels, y_pred=preds, average='weighted')
return {
"acc": acc,
"f1": f1,
"acc_and_f1": (acc + f1) / 2,
}
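# Note: average='weighted' computes per-class F1 weighted by class support,
# which fits the ten-way cnews classification; sklearn's default
# (average='binary') raises an error on multi-class targets.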
class CnewsProcessor(DataProcessor):
    """Processor for the cnews (Chinese news) data set, in GLUE format."""
def get_example_from_tensor_dict(self, tensor_dict):
"""See base class."""
return InputExample(tensor_dict['idx'].numpy(),
tensor_dict['sentence'].numpy().decode('utf-8'),
None,
str(tensor_dict['label'].numpy()))
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "cnews.train.txt")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "cnews.test.txt")), "dev")
def get_labels(self):
"""See base class."""
return ["体育",
"娱乐",
"家居",
"房产",
"教育",
"时尚",
"时政",
"游戏",
"科技",
"财经"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = line[1]
label = line[0]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
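# Expected input format (inferred from _create_examples above): each line of
# cnews.train.txt / cnews.test.txt is "label<TAB>text", where label is one of
# the ten categories returned by get_labels().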
processors["cnews"] = CnesProcessor
output_modes["cnews"] = "classification"
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter('./runs/bert')
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
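    # t_total counts optimizer steps, not batches: with gradient accumulation
    # the optimizer steps once every gradient_accumulation_steps batches, and
    # the LR scheduler below is sized to this count.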
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
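    # Bias and LayerNorm weights are conventionally excluded from L2 weight
    # decay: decaying them does not meaningfully regularize the model and can
    # hurt optimization.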
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
num_training_steps=t_total)
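    # The schedule raises the LR linearly from 0 over warmup_steps, then
    # decays it linearly to 0 at step t_total.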
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distr