#!/usr/bin/env python
# -*- coding: utf-8 -*-
import argparse
import os
import sys
import time

import pandas as pd
from loguru import logger
from pydantic import BaseModel

from textgen.t5 import T5Model


class T5ModelInputParams(BaseModel):
    """Schema for a single model input: task type, question, optional context."""
    taskType: str
    query: str
    context: str = "nan"  # matches pandas' string form of a missing Excel cell
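
# T5ModelInputParams is not exercised below; a minimal illustrative use,
# assuming a caller builds the same input format the training code uses:
#   params = T5ModelInputParams(taskType="QA", query="什么是ai", context="nan")
#   text = f'{params.taskType}: 已知:{params.context},请回答问题:{params.query}'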


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='data/t5MultiTask.xlsx', type=str, help='Training data file')
    parser.add_argument('--model_type', default='t5', type=str, help='Transformers model type')
    parser.add_argument('--model_name', default='models/mengzi-t5-base', type=str, help='Transformers model name or path')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run prediction.')
    parser.add_argument('--do_test', action='store_true', help='Whether to run the interactive test loop.')
    parser.add_argument('--prefix', default='QA', type=str, help='Task prefix string')
    parser.add_argument('--output_dir', default='outputs/mengzi_t5_zh/', type=str, help='Model output directory')
    parser.add_argument('--max_seq_length', default=200, type=int, help='Max input sequence length')
    parser.add_argument('--max_length', default=200, type=int, help='Max output sequence length')
    parser.add_argument('--num_epochs', default=50, type=int, help='Number of training epochs')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    args = parser.parse_args()
    logger.info(args)
    if args.do_train:
        logger.info('Loading data...')
        # train_data: pandas DataFrame with these columns:
        # - `prefix`: a string naming the task to perform (e.g. "QA", "stsb");
        #   it is prepended to form the full model input: "<prefix>: <input_text>"
        # - `input_text`: the question text (rewritten below to include `context`)
        # - `target_text`: the target sequence
        # - `context`: background text merged into `input_text` below
        all_data_df = pd.read_excel(args.train_file)
        # Fold the context into the question so every input carries its evidence:
        # "已知:<context>,请回答问题:<input_text>" ("Given: <context>, answer: <question>")
        all_data_df["input_text"] = all_data_df.apply(
            lambda x: f'已知:{x["context"]},请回答问题:{x["input_text"]}', axis=1)
        logger.debug('train_data: {}'.format(all_data_df[:10]))
        train_df = all_data_df
        # Note: eval_df is a slice of the training data, so eval scores measure fit, not generalization.
        eval_df = all_data_df[7:33]
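        # For example (illustrative row): context="水是无色无味的液体" and
        # input_text="水是什么颜色" are rewritten above into:
        #   "已知:水是无色无味的液体,请回答问题:水是什么颜色"
        #   ("Given: water is a colorless, odorless liquid, answer: what color is water")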
        model_args = {
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": args.max_seq_length,
            "max_length": args.max_length,
            "train_batch_size": args.batch_size,
            "num_train_epochs": args.num_epochs,
            "save_eval_checkpoints": False,
            "save_model_every_epoch": False,
            "evaluate_generated_text": True,
            "evaluate_during_training": True,
            "evaluate_during_training_verbose": True,
            "use_multiprocessing": False,
            "save_best_model": True,
            "output_dir": args.output_dir,
            "use_early_stopping": True,
            "best_model_dir": os.path.join(args.output_dir, "best_model"),
        }
        # e.g. model_type: t5, model_name: Langboat/mengzi-t5-base
        model = T5Model(args.model_type, args.model_name, args=model_args)
        def sim_text_chars(text1, text2):
            """Character-set overlap between two strings, in [0, 1]."""
            if not text1 or not text2:
                return 0.0
            same = set(text1) & set(text2)
            m = len(same)
            n = max(len(set(text1)), len(set(text2)))
            return m / n
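        # For example (illustrative): sim_text_chars("abcd", "abce") shares
        # {"a", "b", "c"} out of a larger set size of 4, so it returns 0.75.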
        def count_matches(labels, preds):
            """Mean character-set similarity between labels and predictions."""
            logger.debug(f"labels: {labels[:10]}")
            logger.debug(f"preds: {preds[:10]}")
            match = sum(sim_text_chars(label, pred) for label, pred in zip(labels, preds)) / len(labels)
            logger.debug(f"match: {match}")
            return match

        model.train_model(train_df, eval_data=eval_df, matches=count_matches)
        print(model.eval_model(eval_df, matches=count_matches))
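        # eval_model() is expected to return a metrics dict keyed by the custom
        # metric names passed above (an assumption from the simpletransformers-style
        # API textgen follows), e.g. {"eval_loss": ..., "matches": ...}.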
    if args.do_predict:
        # Reload the fine-tuned weights from output_dir for inference.
        model = T5Model(args.model_type, args.output_dir, args={"eval_batch_size": args.batch_size})
        # Sample Chinese QA queries: "what is AI", "what type of computer are you",
        # "do you know thermodynamics".
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        # Prepend the task prefix to match the training format: "<prefix>: <input_text>"
        sentences_add_prefix = [args.prefix + ": " + i for i in sentences]
        print("inputs:", sentences)
        print("outputs:", model.predict(sentences_add_prefix))
        # Rough throughput check: repeat the inputs 50x and time one batched predict call.
        sentences_add_prefix = sentences_add_prefix * 50
        t1 = time.time()
        res = model.predict(sentences_add_prefix)
        print(type(res), len(res))
        logger.info(f'spend time: {time.time() - t1}, size: {len(sentences_add_prefix)}')
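        # predict() is assumed to return a list of generated strings aligned
        # one-to-one with the inputs, batched internally by eval_batch_size.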
    if args.do_test:
        model = T5Model(args.model_type, args.output_dir, args={"eval_batch_size": args.batch_size})
        while True:
            taskType = input("请输入任务类型:")  # "Enter the task type:"
            query = input("请输入查询问题:")  # "Enter the question:" (type "exit" to quit)
            if query == "exit":
                break
            context = input("请输入上下文:")  # "Enter the context:"
            # Rebuild the exact training-time input format before predicting.
            sentence = f'已知:{context},请回答问题:{query}'
            sentence_add_prefix = [taskType + ": " + sentence]
            print("outputs:", model.predict(sentence_add_prefix))


if __name__ == '__main__':
    main()