from typing import List, Tuple
import tensorflow as tf
import numpy as np
from ..models import CCLMModelBase, TransformerBlock
from ..preprocessing import Preprocessor
from .core import Pretrainer


class MaskedLanguagePretrainer(tf.keras.Model, Pretrainer):
def __init__(
self,
*args,
n_conv_filters: int = 128,
downsample_factor: int = 4,
n_strided_convs: int = 2,
stride_len: int = 2,
mask_id: int = 1,
learning_rate: float = 0.001,
train_base: bool = True,
training_pool_mode: str = "local",
min_mask_len: int = 3,
        num_negatives: int = 5,
**kwargs,
):
"""
Pretrain a model by doing masked language modeling on a corpus.
Model that is trained accepts the base input with shape (batch_size, example_len, base_embedding_dim)
and performs convolutions on the input. The input can be downsampled using strided conv layers,
making the Transformer layer less expensive. The kernel_size of the conv layers
also match the stride_len for simplicity
"""
tf.keras.Model.__init__(self)
self.n_strided_convs = n_strided_convs
self.downsample_factor = downsample_factor
self.training_pool_mode = training_pool_mode
self.min_mask_len = min_mask_len
self.mask_id = mask_id
# calculate stride length to achieve the right downsampling
assert (
stride_len ** n_strided_convs == downsample_factor
), "stride_len^n_strided_convs must equal downsample_factor"
self.stride_len = int(stride_len)
self.n_conv_filters = n_conv_filters
self.pool = tf.keras.layers.GlobalMaxPool1D(dtype="float32")
self.optimizer = tf.optimizers.SGD(learning_rate)
self.num_negatives = num_negatives
self.train_base = train_base
Pretrainer.__init__(self, *args, **kwargs)
# initialize output weights for the sampled softmax or nce loss
self.output_weights = tf.Variable(
tf.random.normal(
[self.preprocessor.tokenizer.get_vocab_size() + 1, self.n_conv_filters]
)
)
self.output_biases = tf.Variable(
tf.zeros([self.preprocessor.tokenizer.get_vocab_size() + 1])
)
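        # minimal sketch (not the pretrainer's actual training step, which is not
        # shown here): these variables match what tf.nn.sampled_softmax_loss /
        # tf.nn.nce_loss expect, i.e. weights of shape [num_classes, dim] and biases
        # of shape [num_classes]; `pooled` and `y` below are hypothetical stand-ins
        # for the pooled conv features and the true masked-token ids:
        #   loss = tf.nn.sampled_softmax_loss(
        #       weights=self.output_weights,
        #       biases=self.output_biases,
        #       labels=tf.reshape(y, (-1, 1)),  # (batch_size, 1) int ids
        #       inputs=pooled,                  # (batch_size, n_conv_filters)
        #       num_sampled=self.num_negatives,
        #       num_classes=self.preprocessor.tokenizer.get_vocab_size() + 1,
        #   )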
# negative sampling head
self.classification_head = tf.keras.Sequential(
[
tf.keras.layers.Dense(128, activation="tanh"),
tf.keras.layers.Dense(1, activation="sigmoid", dtype="float32"),
]
)
self.concat = tf.keras.layers.Concatenate()
self.output_embedding = tf.keras.layers.Embedding(
self.preprocessor.tokenizer.get_vocab_size() + 1, self.n_conv_filters
)
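        # minimal sketch of how the negative-sampling head could score a candidate
        # token (an assumption for illustration; the actual wiring is not shown here).
        # `pooled` is a hypothetical (batch, n_conv_filters) feature tensor and
        # `token_ids` a hypothetical (batch,) int tensor of candidate ids:
        #   cand = self.output_embedding(token_ids)           # (batch, n_conv_filters)
        #   score = self.classification_head(
        #       self.concat([pooled, cand])                    # (batch, 2 * n_conv_filters)
        #   )                                                  # (batch, 1), sigmoid score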
self.pretraining_model = self.get_pretraining_model()

    def get_model(
self,
):
"""
Until handled better, inputs need to be padded to a multiple of filter_stride_len*n_strided_convs.
The model uses one or more strided Conv1D to reduce the input shape before passing it to
one or more transformer blocks
"""
# reduce the size, transformer, upsample
layers = [
tf.keras.layers.Conv1D(
self.n_conv_filters,
self.stride_len,
strides=self.stride_len,
padding="same",
activation="tanh",
)
for _ in range(self.n_strided_convs)
]
model = tf.keras.Sequential(
[
tf.keras.layers.Conv1D(
self.n_conv_filters,
self.stride_len,
padding="same",
activation="tanh",
),
tf.keras.layers.Dropout(0.2),
*layers,
tf.keras.layers.Dropout(0.2),
TransformerBlock(embed_dim=self.n_conv_filters),
                # upsample back toward the original sequence length if it was downsampled
                *(
                    [tf.keras.layers.UpSampling1D(size=self.downsample_factor)]
                    if self.downsample_factor > 1
                    else []
                ),
tf.keras.layers.Conv1D(
self.n_conv_filters,
self.stride_len,
padding="same",
activation="tanh",
),
]
)
return model
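
    # minimal shape sketch (assuming the defaults n_conv_filters=128 and
    # downsample_factor=4, plus a hypothetical base embedding dim of 32; the
    # sequence length 128 is already a multiple of downsample_factor):
    #   model = pretrainer.get_model()        # `pretrainer` is a hypothetical instance
    #   x = tf.random.normal((2, 128, 32))    # (batch_size, example_len, base_embedding_dim)
    #   model(x).shape                        # (2, 128, 128): length restored by UpSampling1D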

    def can_learn_from(self, example_str: str) -> bool:
"""
Decide whether it's appropriate to learn from this example
"""
# if it's an empty string, skip it
if example_str == "":
return False
        # if it's just [CLS], a few tokens, and [SEP], skip it
if len(example_str) < 20:
return False
return True
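
    # e.g. can_learn_from("") -> False, can_learn_from("[CLS] hi [SEP]") -> False
    # (fewer than 20 characters), while any reasonably long sentence returns True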

    def batch_from_strs(
self, input_strs: List[str]
) -> Tuple[List[str], List[Tuple[int, int, str]], List[int]]:
"""
Transform input strings into correct-length substrings and pick tokens to mask
"""
batch_inputs: List[str] = []
batch_outputs: List[int] = []
batch_spans: List[Tuple[int, int, str]] = []
tokenizer = self.preprocessor.tokenizer
for example in input_strs:
            # subset to a substring of the correct length
            example = self.get_substr(example)
encoded = tokenizer.encode(example)
# get all tokens that are long enough to be masked
possible_masked_tokens = [
_id
for n, _id in enumerate(encoded.ids)
if len(encoded.tokens[n]) >= self.min_mask_len
]
possible_encoding_indexes = [
n
for n, token in enumerate(encoded.tokens)
if len(token) >= self.min_mask_len
]
            # if no token is long enough, allow any token to be masked
            if len(possible_masked_tokens) == 0:
                possible_masked_tokens = encoded.ids
                possible_encoding_indexes = list(range(len(encoded.ids)))
# sample a value to index into possible_masked_tokens
masked_token_index = np.random.randint(0, len(possible_masked_tokens))
# look up the value sampled
masked_token_id = possible_masked_tokens[masked_token_index]
start, end = encoded.token_to_chars(
possible_encoding_indexes[masked_token_index]
)
masked_token_len = end - start
inp = (
example[:start]
+ "?" * masked_token_len
+ example[start + masked_token_len :]
)
batch_inputs.append(inp)
batch_spans.append((start, end, example[start : start + masked_token_len]))
batch_outputs.append(masked_token_id)
return batch_inputs, batch_spans, batch_outputs
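
    # hypothetical example (the masked token is chosen at random, so the exact span
    # varies): batch_from_strs(["the quick brown fox"]) might return
    #   (["the ????? brown fox"], [(4, 9, "quick")], [<token id of "quick">])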

    def generator(self, data: List[str], batch_size: int):
"""
Generator for training purposes
"""
while True:
            batch_inputs, batch_outputs, batch_spans, batch_strs = [], [], [], []
for n, example in enumerate(data):
example = example.strip()
if not self.can_learn_from(example):
continue
batch_strs.append(example)
if len(batch_strs) == batch_size or (
n + 1 == len(data) and len(batch_strs) > 0
):
batch_inputs, batch_spans, batch_outputs = self.batch_from_strs(
batch_strs
)
x = self.get_batch(batch_inputs, batch_spans=batch_spans)
# x = tf.concat([x, *[x for _ in range(self.num_negatives)]], axis=0)
y = np.array(batch_outputs)
y_sample, t