from collections import defaultdict
import time
from joblib import Parallel, delayed
from multiprocessing import cpu_count
from math import ceil
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from nltk.corpus import stopwords
from transformers import BertTokenizer, AdamW, get_linear_schedule_with_warmup
import numpy as np
import os
import shutil
import sys
from tqdm import tqdm
from model import LOTClassModel
import warnings
import nltk
# nltk.download('stopwords')  # uncomment on first run if the stopwords corpus is not yet installed
warnings.filterwarnings("ignore")
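# Trainer for LOTClass: weakly-supervised text classification using label names only
# (masked category prediction + self-training on top of a pretrained BERT).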
class LOTClassTrainer(object):
def __init__(self, args):
self.args = args
self.max_len = args.max_len
self.dataset_dir = args.dataset_dir
self.dist_port = args.dist_port
self.num_cpus = min(10, cpu_count() - 1) if cpu_count() > 1 else 1
self.world_size = args.gpus
self.train_batch_size = args.train_batch_size
self.eval_batch_size = args.eval_batch_size
self.accum_steps = args.accum_steps
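# effective batch size = per-GPU batch size * number of GPUs * gradient accumulation steps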
eff_batch_size = self.train_batch_size * self.world_size * self.accum_steps
assert abs(eff_batch_size - 128) < 10, f"Make sure the effective training batch size is around 128, current: {eff_batch_size}"
print(f"Effective training batch size: {eff_batch_size}")
self.pretrained_lm = 'bert-base-uncased'
self.tokenizer = BertTokenizer.from_pretrained(self.pretrained_lm, do_lower_case=True)
self.vocab = self.tokenizer.get_vocab()
self.vocab_size = len(self.vocab)
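# id of the [MASK] token and an inverse mapping from token ids back to tokens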
self.mask_id = self.vocab[self.tokenizer.mask_token]
self.inv_vocab = {token_id: token for token, token_id in self.vocab.items()}
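# read the label names file; the number of classes equals the number of label-name entries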
self.read_label_names(args.dataset_dir, args.label_names_file)
self.num_class = len(self.label_name_dict)
self.model = LOTClassModel.from_pretrained(self.pretrained_lm,
output_attentions=False,
output_hidden_states=False,
num_labels=self.num_class)
self.read_data(args.dataset_dir, args.train_file, args.test_file, args.test_label_file)
self.with_test_label = args.test_label_file is not None
self.temp_dir = f'tmp_{self.dist_port}'
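# masked category prediction (MCP) loss and self-training loss (KL divergence against soft targets)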
self.mcp_loss = nn.CrossEntropyLoss()
self.st_loss = nn.KLDivLoss(reduction='batchmean')
self.update_interval = args.update_interval
self.early_stop = args.early_stop
# set up distributed training
def set_up_dist(self, rank):
dist.init_process_group(
backend='nccl',  # NCCL backend for multi-GPU distributed training
init_method=f'tcp://localhost:{self.dist_port}',
world_size=self.world_size,
rank=rank
)
# create local model
model = self.model.to(rank)
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
return model
# get document truncation statistics with the defined max length
def corpus_trunc_stats(self, docs):
doc_len = []
for doc in docs:
input_ids = self.tokenizer.encode(doc, add_special_tokens=True)
doc_len.append(len(input_ids))
print(f"Document max length: {np.max(doc_len)}, avg length: {np.mean(doc_len)}, std length: {np.std(doc_len)}")
trunc_frac = np.sum(np.array(doc_len) > self.max_len) / len(doc_len)
print(f"Truncated fraction of all documents: {trunc_frac}")
# convert a list of strings to token ids
def encode(self, docs):
encoded_dict = self.tokenizer.batch_encode_plus(docs, add_special_tokens=True, max_length=self.max_len, padding='max_length',
return_attention_mask=True, truncation=True, return_tensors='pt')
input_ids = encoded_dict['input_ids']
attention_masks = encoded_dict['attention_mask']
return input_ids, attention_masks
# convert list of token ids to list of strings
def decode(self, ids):
strings = self.tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
return strings
# convert dataset into tensors
def create_dataset(self, dataset_dir, text_file, label_file, loader_name, find_label_name=False, label_name_loader_name=None):
loader_file = os.path.join(dataset_dir, loader_name)
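# reuse cached encoded tensors if they exist; otherwise tokenize the raw text file and cache the result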
if os.path.exists(loader_file):
print(f"Loading encoded texts from {loader_file}")
data = torch.load(loader_file)
else:
print(f"Reading texts from {os.path.join(dataset_dir, text_file)}")
with open(os.path.join(dataset_dir, text_file), encoding="utf-8") as corpus:
    docs = [doc.strip() for doc in corpus.readlines()]
print(f"Converting texts into tensors.")
chunk_size = ceil(len(docs) / self.num_cpus)
chunks = [docs[x:x+chunk_size] for x in range(0, len(docs), chunk_size)]
results = Parallel(n_jobs=self.num_cpus)(delayed(self.encode)(docs=chunk) for chunk in chunks)
input_ids = torch.cat([result[0] for result in results])
attention_masks = torch.cat([result[1] for result in results])
print(f"Saving encoded texts into {loader_file}")
if label_file is not None:
print(f"Reading labels from {os.path.join(dataset_dir, label_file)}")
with open(os.path.join(dataset_dir, label_file)) as truth:
    labels = [int(label.strip()) for label in truth.readlines()]
labels = torch.tensor(labels)
data = {"input_ids": input_ids, "attention_masks": attention_masks, "labels": labels}
else:
data = {"input_ids": input_ids, "attention_masks": attention_masks}
torch.save(data, loader_file)
if find_label_name:
loader_file = os.path.join(dataset_dir, label_name_loader_name)
if os.path.exists(loader_file):
print(f"Loading texts with label names from {loader_file}")
label_name_data = torch.load(loader_file)
else:
print(f"Reading texts from {os.path.join(dataset_dir, text_file)}")
with open(os.path.join(dataset_dir, text_file), encoding="utf-8") as corpus:
    docs = [doc.strip() for doc in corpus.readlines()]
print("Locating label names in the corpus.")
chunk_size = ceil(len(docs) / self.num_cpus)
chunks = [docs[x:x+chunk_size] for x in range(0, len(docs), chunk_size)]
results = Parallel(n_jobs=self.num_cpus)(delayed(self.label_name_occurrence)(docs=chunk) for chunk in chunks)
input_ids_with_label_name = torch.cat([result[0] for result in results])
attention_masks_with_label_name = torch.cat([result[1] for result in results])
label_name_idx = torch.cat([result[2] for result in results])
assert len(input_ids_with_label_name) > 0, "No label names appear in corpus!"
label_name_data = {"input_ids": input_ids_with_label_name, "attention_masks": attention_masks_with_label_name, "labels": label_name_idx}
loader_file = os.path.join(dataset_dir, label_name_loader_name)
print(f"Saving texts with label names into {loader_file}")
torch.save(label_name_data, loader_file)
return data, label_name_data
else:
return data
# find label name indices and replace out-of-vocab label names with [MASK]
def label_name_in_doc(self, doc):
doc = self.tokenizer.tokenize(doc)
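# label_idx[i] holds the class index if token i is part of a label name, -1 otherwise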
label_idx = -1 * torch.ones(self.max_len, dtype=torch.long)
new_doc = []
wordpcs = []