Text Classification
-------------------------------------------------------------------------
The purpose of this repository is to explore text classification methods in NLP with deep learning.
UPDATE: if you want to try a model now, go to the folder 'a02_TextCNN' and run 'python -u p7_TextCNN_train.py'. It will train a model on the sample data and periodically print the loss and F1 score.
It contains a variety of baseline models for text classification.
It also supports multi-label classification, where multiple labels are associated with a sentence or document.
Although many of these models are simple and may not get you to the top level of the task, some of them are classics, so they serve well as baseline models.
Each model has a test function under its model class; you can run it first on a toy task. The models are independent of any particular dataset.
<a href='https://github.com/brightmart/text_classification/blob/master/multi-label-classification.pdf'>Check here for a formal report on large-scale multi-label text classification with deep learning</a>
Several models here can also be used for modelling question answering (with or without context), or for sequence generation.
We explore two seq2seq models (seq2seq with attention, and the Transformer from "Attention Is All You Need") for text classification. Both models can also be used for sequence generation and other tasks. If your task is multi-label classification, you can cast it as sequence generation.
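For example, a multi-label target can be rewritten as a short label sequence and decoded token by token. A minimal sketch (the label names and the `_GO`/`_END` tokens are only illustrative, not taken from this repository):

```python
# Hypothetical example: a multi-label target written as a short token sequence,
# so a seq2seq / Transformer decoder can generate the labels one by one.
labels = {"sports", "finance"}                        # the document's label set
target_sequence = ["_GO"] + sorted(labels) + ["_END"]
print(target_sequence)                                # ['_GO', 'finance', 'sports', '_END']
```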
We implement two memory networks. The first is the Dynamic Memory Network, which previously reached state of the art on question answering, sentiment analysis and sequence generation tasks; it is a so-called "one model for several different tasks" that still reaches high performance. It has four modules, and the key component is the episodic memory module: it uses a gating mechanism to perform attention, a gated GRU to update the episode memory, and another GRU (in the vertical direction) to update the hidden state. It is able to do transitive inference.
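A minimal sketch of the gated update inside the episodic memory module (the names are illustrative and not the repository's; `attention` and `gru_step` are assumed to be supplied by the caller):

```python
import numpy as np

def episode_update(facts, question, memory, attention, gru_step):
    """One pass of the episodic memory module (simplified): a gate scores each
    fact against the question and the current memory, then a gated GRU step
    decides how much of that fact enters the episode state."""
    h = np.zeros_like(memory)                   # episode state
    for fact in facts:                          # facts: encoded sentences
        g = attention(fact, question, memory)   # scalar gate in [0, 1]
        h = g * gru_step(fact, h) + (1.0 - g) * h
    return h                                    # later used to update the memory
```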
The second memory network we implemented is the Recurrent Entity Network ("Tracking the State of the World"). It keeps blocks of key-value pairs as memory which are updated in parallel, and it achieved a new state of the art. It can be used for modelling question answering with context (or history): for example, you can let the model read some sentences (as context), ask a question (as query), and have it predict an answer; if you feed the story in as the query as well, it can perform a classification task.
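A rough NumPy sketch of one memory update step with a gated write over parallel key-value blocks (the equations are simplified from the paper and the names are illustrative, not copied from the code here):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def entity_net_step(keys, values, s_t, U, V, W):
    """Update all memory blocks in parallel for one encoded input s_t.
    keys, values: [num_blocks, dim]; U, V, W: [dim, dim] shared weights."""
    gates = sigmoid(values @ s_t + keys @ s_t)                   # one gate per block
    candidate = np.tanh(values @ U.T + keys @ V.T + s_t @ W.T)   # candidate content
    values = values + gates[:, None] * candidate                 # gated write
    norms = np.linalg.norm(values, axis=1, keepdims=True)
    return values / np.maximum(norms, 1e-8)                      # renormalize each block
```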
If you need some sample data and word embeddings pretrained with word2vec, you can find them in closed issues, such as <a href="https://github.com/brightmart/text_classification/issues/3">issue 3</a>.
You can also find sample data in the folder "data". It contains two files: 'sample_single_label.txt', with 50k single-label examples, and 'sample_multiple_label.txt', with 20k multi-label examples. Input and label(s) are separated by " __label__".
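A minimal sketch of how one line of these files can be split on the "__label__" separator (the line below is a toy example, not real data):

```python
# Split one line of the sample files into the input text and its label(s).
line = "x1 x2 x3 x4 x5 __label__ 323434"
text, labels = line.split("__label__", 1)
text, label_list = text.strip(), labels.split()
print(text, label_list)   # "x1 x2 x3 x4 x5" ['323434']
```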
If you want to know more about text classification datasets, or about tasks these models can be used for, one option is the competition below:
https://biendata.com/competition/zhihu/
Models:
-------------------------------------------------------------------------
1) fastText
2) TextCNN
3) TextRNN
4) RCNN
5) Hierarchical Attention Network
6) seq2seq with attention
7) Transformer("Attention Is All You Need")
8) Dynamic Memory Network
9) EntityNetwork:tracking state of the world
10) Ensemble models
11) Boosting:
For a single model, stack identical models together; each layer is a model, and the result is based on the logits of all layers added together. The only connection between layers is the label weights: the front layer's per-label prediction error rate becomes the weight for the next layer, so labels with a high error rate get a large weight. Later layers therefore pay more attention to the mis-predicted labels and try to fix the mistakes of the former layers. As a result, we get a much stronger model.
check a00_boosting/boosting.py
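A rough NumPy sketch of this idea (illustrative only; the function names are not from boosting.py, which has the actual implementation):

```python
import numpy as np

def label_weights_from_errors(error_rate_per_label):
    """The front layer's per-label error rate becomes the next layer's label
    weights: labels that were often mis-predicted get a bigger weight."""
    w = 1.0 + error_rate_per_label
    return w / w.mean()                      # keep the average weight around 1

def weighted_loss(per_label_loss, label_weights):
    """The next layer's loss pays more attention to mis-predicted labels."""
    return float(np.sum(per_label_loss * label_weights))

def final_logits(layer_logits):
    """The ensemble prediction is the logits of all layers added together."""
    return np.sum(layer_logits, axis=0)
```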
and other models:
1) BiLstmTextRelation;
2) twoCNNTextRelation;
3) BiLstmTextRelationTwoRNN
Performance
-------------------------------------------------------------------------
(multi-label prediction task: predict the top 5 labels; 3 million training examples; full score: 0.5)
Model | fastText|TextCNN|TextRNN| RCNN | HierAtteNet|Seq2seqAttn|EntityNet|DynamicMemory|Transformer
--- | --- | --- | --- |--- |--- |--- |--- |--- |----
Score | 0.362 | 0.405| 0.358 | 0.395| 0.398 |0.322 |0.400 |0.392 |0.322
Training| 10m | 2h |10h | 2h | 2h |3h |3h |5h |7h
--------------------------------------------------------------------------------------------------
Ensemble of TextCNN, EntityNet and DynamicMemory: 0.411
Ensemble of EntityNet and DynamicMemory: 0.403
--------------------------------------------------------------------------------------------------
Notice:
`m` stands for **minutes**; `h` stands for **hours**;
`HierAtteNet` means Hierarchical Attention Network;
`Seq2seqAttn` means Seq2seq with attention;
`DynamicMemory` means Dynamic Memory Network;
`Transformer` stands for the model from 'Attention Is All You Need'.
Usage:
-------------------------------------------------------------------------------------------------------
1) the model is in `xxx_model.py`
2) run `python xxx_train.py` to train the model
3) run `python xxx_predict.py` to do inference (test).
Each model has a test method under its model class. You can run the test method first to check whether the model works properly.
-------------------------------------------------------------------------
Environment:
-------------------------------------------------------------------------------------------------------
python 2.7 + TensorFlow 1.1
(TensorFlow 1.2, 1.3 and 1.4 also work; most models should work fine with other TensorFlow versions as well, since we use very few features bound to a specific version. Python 3.5 is also fine as long as you change the print statements and try/except blocks.)
The TextCNN model has already been ported to python 3.6.
-------------------------------------------------------------------------
Notice:
-------------------------------------------------------------------------------------------------------
Some utility functions are in data_util.py;
a typical input looks like: "x1 x2 x3 x4 x5 __label__ 323434", where 'x1, x2, ...' are words and '323434' is the label;
it has a function to load pretrained word embeddings (trained with word2vec or fastText) and assign them to the model.
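A minimal sketch of assigning a pretrained embedding matrix to a model variable in TensorFlow 1.x (the names and shapes are illustrative; this is not data_util.py's actual API):

```python
import numpy as np
import tensorflow as tf

vocab_size, embed_size = 50000, 100
embedding = tf.get_variable("embedding", shape=[vocab_size, embed_size])

# Feed the pretrained matrix (from word2vec or fastText) in once at startup
# via a placeholder and an assign op, instead of baking it into the graph.
pretrained = tf.placeholder(tf.float32, shape=[vocab_size, embed_size])
assign_embedding = tf.assign(embedding, pretrained)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    word2vec_matrix = np.random.randn(vocab_size, embed_size).astype(np.float32)  # stand-in
    sess.run(assign_embedding, feed_dict={pretrained: word2vec_matrix})
```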
Models Detail:
-------------------------------------------------------------------------
1.fastText:
-------------
Implementation of <a href="https://arxiv.org/abs/1607.01759">Bag of Tricks for Efficient Text Classification</a>
After each word in the sentence is embedded, the word representations are averaged into a text representation, which is in turn fed to a linear classifier. A softmax function computes the probability distribution over the predefined classes, and cross-entropy is used to compute the loss. A bag-of-words representation does not consider word order, so n-gram features are used to capture some partial information about the local word order. When the number of classes is large, computing the linear classifier is computationally expensive, so hierarchical softmax is used to speed up training.
1) use bi-gram and/or tri-gram
2) use NCE loss to speed up the softmax computation (instead of the hierarchical softmax from the original paper)
result: performance is as good as in the paper, and speed is also very fast.
check: p5_fastTextB_model.py
![alt text](https://github.com/brightmart/text_classification/blob/master/images/fastText.JPG)
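A minimal NumPy sketch of the forward pass described above: average the word vectors, apply a linear classifier, then a softmax (the names and shapes are illustrative; the actual model uses NCE loss during training):

```python
import numpy as np

def fasttext_forward(word_ids, embedding, W, b):
    """word_ids: indices of the document's tokens (uni-grams plus optional n-gram ids).
    embedding: [vocab, dim]; W: [dim, num_classes]; b: [num_classes]."""
    doc_vector = embedding[word_ids].mean(axis=0)   # average the word vectors
    logits = doc_vector @ W + b                     # linear classifier
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()                      # softmax over the classes
```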
-------------------------------------------------------------------------
2.TextCNN:
-------------
Implementation of <a href="http://www.aclweb.org/anthology/D14-1181">Convolutional Neural Networks for Sentence Classification</a>