#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date : 2021-04-22 20:43:06
# @Author : Chenghao Mou (mouchenghao@gmail.com)
"""base covers all the base classes, functions for other embedding based tokenizers."""
import abc
from typing import List, Optional, Union, Dict
from itertools import zip_longest
import numpy as np
from transformers.tokenization_utils_base import (
    BatchEncoding,
    EncodedInput,
    PaddingStrategy,
    PreTrainedTokenizerBase,
    TensorType,
    TextInput,
    TruncationStrategy,
    is_torch_available,
    to_py_obj,
)
def is_torch(x) -> bool:  # pragma: no cover
"""
Helper function to check whether the input is a torch tensor.
Parameters
----------
    x : Any
        Input object to check
Returns
-------
bool
Boolean value indicating whether the input is a torch tensor
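
    Examples
    --------
    A quick sanity check with a torch tensor and a numpy array:

    >>> import torch
    >>> import numpy as np
    >>> is_torch(torch.zeros(2, 3))
    True
    >>> is_torch(np.zeros((2, 3)))
    False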
"""
import torch
return isinstance(x, torch.Tensor)
class EmbeddingTokenizer(PreTrainedTokenizerBase):
"""
    Embedding-based tokenizer. It assumes each token is mapped to an embedding tensor instead of an integer index.
    Most of the implementation is borrowed from Hugging Face's transformers library.
Parameters
----------
model_input_names : Optional[List[str]], optional
Required model input names, by default None
special_tokens : Optional[Dict[str, np.ndarray]], optional
Required model special tokens, by default None
max_length : Optional[int], optional
Maximum sequence length supported by the model, by default 2048
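
    Examples
    --------
    A minimal sketch of a concrete subclass; ``CharTokenizer`` and its
    per-character embedding below are hypothetical and for illustration only:

    >>> import numpy as np
    >>> class CharTokenizer(EmbeddingTokenizer):
    ...     def text2embeddings(self, text):
    ...         if text is None:
    ...             return None
    ...         # one 8-dimensional vector per character, derived from its code point
    ...         return np.stack([np.full(8, ord(char), dtype=np.float32) for char in text])
    >>> tokenizer = CharTokenizer(max_length=16)
    >>> tokenizer.text2embeddings("abc").shape
    (3, 8)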
"""
    def __init__(
        self,
        model_input_names: Optional[List[str]] = None,
        special_tokens: Optional[Dict[str, np.ndarray]] = None,
        max_length: Optional[int] = 2048,
    ):
        # Initialize the huggingface base class so the attributes used by the
        # inherited padding machinery (e.g. ``padding_side``, ``model_input_names``,
        # ``deprecation_warnings``) are set up.
        super().__init__()
        if model_input_names is not None:
            self.model_input_names = model_input_names
        self.special_tokens = special_tokens
        self.max_length = max_length
    @abc.abstractmethod
    def text2embeddings(self, text: str) -> np.ndarray:
        """Convert a piece of text into an array of token embeddings, one embedding per token.
        Subclasses must implement this method.
        """
        raise NotImplementedError("text2embeddings must be implemented by a subclass")
def __call__(
self,
text: Union[TextInput, List[TextInput]],
text_pair: Optional[Union[TextInput, List[TextInput]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_length: bool = False,
**kwargs,
) -> BatchEncoding:
"""
        Tokenize the input text (and optional text pair) into sequences of token embeddings.
Parameters
----------
        text : Union[TextInput, List[TextInput]]
            A single text or a list of texts
        text_pair : Optional[Union[TextInput, List[TextInput]]], optional
            A single text or a list of texts paired with `text`, by default None
        add_special_tokens : bool, optional
            Whether to add special tokens to the encoded sequences, by default True
padding : Union[bool, str, PaddingStrategy], optional
The padding strategy, by default False
truncation : Union[bool, str, TruncationStrategy], optional
The truncation strategy, by default False
        max_length : Optional[int], optional
            Maximum sequence length, overriding ``self.max_length``, by default None
        pad_to_multiple_of : Optional[int], optional
            Pad the sequence length to a multiple of this value, by default None
        return_tensors : Optional[Union[str, TensorType]], optional
            Return tensors as `'pt'`, `'tf'`, or `'np'`, by default None
return_token_type_ids : Optional[bool], optional
Return token type ids, by default None
return_attention_mask : Optional[bool], optional
Return attention mask, by default None
return_overflowing_tokens : bool, optional
Return overflowing tokens, by default False
return_special_tokens_mask : bool, optional
Return special token mask, by default False
return_length : bool, optional
Return length, by default False
Returns
-------
BatchEncoding
A BatchEncoding object
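
        Examples
        --------
        A usage sketch, assuming a concrete subclass (such as the hypothetical
        ``CharTokenizer`` from the class docstring) whose padding logic can
        handle embedding inputs::

            tokenizer = CharTokenizer(max_length=16)
            batch = tokenizer(["hello", "world"], return_tensors="np")
            # ``batch`` is a BatchEncoding whose "input_ids" entry holds one
            # embedding matrix per input text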
"""
if self.special_tokens is None:
self.special_tokens = {
"CLS": self.text2embeddings("[CLS]"),
"SEP": self.text2embeddings("[SEP]"),
}
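        # Reserve room in the length budget for the special tokens added by
        # ``build_inputs_with_special_tokens`` ([CLS] text [SEP] text_pair [SEP]).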
if add_special_tokens and text_pair:
actual_max_length = self.max_length - len(self.special_tokens["SEP"]) * 2 - len(self.special_tokens["CLS"])
else:
actual_max_length = self.max_length
batch_outputs = {}
text = text if isinstance(text, list) else [text]
text_pair = text_pair if isinstance(text_pair, list) else [text_pair]
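        # Normalize string arguments into their corresponding strategy enums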
if isinstance(padding, str):
padding = PaddingStrategy(padding)
if isinstance(truncation, str):
truncation = TruncationStrategy(truncation)
        for first_text, second_text in zip_longest(text, text_pair, fillvalue=None):
            first_embeddings = self.text2embeddings(first_text)
            # ``second_text`` is None when no text pair is provided for this example
            second_embeddings = self.text2embeddings(second_text) if second_text is not None else None
outputs = self.prepare_for_model(
first_embeddings,
second_embeddings,
add_special_tokens=add_special_tokens,
padding=PaddingStrategy.DO_NOT_PAD, # we pad in batch afterward
truncation=truncation,
max_length=max_length or actual_max_length,
pad_to_multiple_of=None, # we pad in batch afterward
return_attention_mask=False, # we pad in batch afterward
return_token_type_ids=return_token_type_ids,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_length=return_length,
return_tensors=None, # We convert the whole batch to tensors at the end
prepend_batch_axis=False,
)
for key, value in outputs.items():
if key not in batch_outputs:
batch_outputs[key] = []
batch_outputs[key].append(value)
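        # Pad the whole batch at once so all examples share the same sequence length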
batch_outputs = self.pad(
batch_outputs,
padding=padding,
max_length=max_length or actual_max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
)
batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
return batch_outputs
    def build_inputs_with_special_tokens(
        self, token_ids_0: np.ndarray, token_ids_1: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """Concatenate the special token embeddings with the input embeddings.
        A single sequence is returned unchanged; a pair is laid out as
        ``[CLS] token_ids_0 [SEP] token_ids_1 [SEP]``.
        """
        if token_ids_1 is None:
            return token_ids_0
return np.concatenate(
[
self.special_tokens["CLS"],
token_ids_0,
self.special_tokens["SEP"],
token_ids_1,
self.special_tokens["SEP"],
],
axis=0
)
def prepare_for_model(
self,
        ids: np.ndarray,
        pair_ids: Optional[np.ndarray] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_token_type_ids: Optional[bool] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,