# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
import modeling
import optimization
import tokenization
import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS
## Required parameters
flags.DEFINE_string(
"data_dir", None,
"The input data dir. Should contain the .tsv files (or other data files) "
"for the task.")
flags.DEFINE_string(
"bert_config_file", None,
"The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
flags.DEFINE_string("task_name", None, "The name of the task to train.")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
## Other parameters
flags.DEFINE_string(
"init_checkpoint", None,
"Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")
flags.DEFINE_bool("do_train", False, "Whether to run training.")
flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.")
flags.DEFINE_bool(
"do_predict", False,
"Whether to run the model in inference mode on the test set.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("num_train_epochs", 3.0,
"Total number of training epochs to perform.")
flags.DEFINE_float(
"warmup_proportion", 0.1,
"Proportion of training to perform linear learning rate warmup for. "
"E.g., 0.1 = 10% of training.")
flags.DEFINE_integer("save_checkpoints_steps", 1000,
"How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 1000,
"How many steps to make in each estimator call.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
tf.flags.DEFINE_string(
"tpu_name", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
tf.flags.DEFINE_string(
"tpu_zone", None,
"[Optional] GCE zone where the Cloud TPU is located in. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string(
"gcp_project", None,
"[Optional] Project name for the Cloud TPU-enabled project. If not "
"specified, we will attempt to automatically detect the GCE project from "
"metadata.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_integer(
"num_tpu_cores", 8,
"Only used if `use_tpu` is True. Total number of TPU cores to use.")
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class PaddingInputExample(object):
"""Fake example so the num input examples is a multiple of the batch size.
When running eval/predict on the TPU, we need to pad the number of examples
to be a multiple of the batch size, because the TPU requires a fixed batch
size. The alternative is to drop the last batch, which is bad because it means
the entire output data won't be generated.
We use this class instead of `None` because treating `None` as padding
battches could cause silent errors.
"""
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self,
input_ids,
input_mask,
segment_ids,
label_id,
is_real_example=True):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
def _read_csv(cls, input_file, quotechar=None):
"""Reads a comma separated value file."""
with tf.gfile.Open(input_file, "r") as f:
reader = csv.reader(f, delimiter=",", quotechar=quotechar)
lines = []
for line in reader:
lines.append(line)
return lines
class XnliProcessor(DataProcessor):
"""Processor for the XNLI data set."""
def __init__(self):
self.language = "zh"
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % self.language))
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % (i)
text_a = tokenization.convert_to_unicode(line[0])
text_b = tokenization.convert_to_unicode(line[1])
label = tokenization.convert_to_unicode(line[2])
if label == tokenization.convert_to_unicode("contradictory"):
label = tokenization.convert_to_unicode("contradiction")
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path
没有合适的资源?快使用搜索试试~ 我知道了~
twitter预处理和bert分类代码.rar
共2个文件
py:2个
5星 · 超过95%的资源 需积分: 13 12 下载量 84 浏览量
2019-08-27
16:38:10
上传
评论 1
收藏 10KB RAR 举报
温馨提示
压缩包中包含两个代码,一个是对twitter数据进行预处理的,另一个是个人修改的用bert对twitter数据分类的run_classifier.py代码,希望可以帮助到大家。
资源详情
资源评论
资源推荐
收起资源包目录
twitter预处理和bert分类代码.rar (2个子文件)
run_classifier.py 37KB
twitter_preprocess.py 2KB
共 2 条
- 1
中科小白
- 粉丝: 14
- 资源: 5
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论1