# TextGAN-PyTorch
TextGAN is a PyTorch framework for text generation models based on Generative Adversarial Networks (GANs), including general text generation models and category text generation models. TextGAN serves as a benchmarking platform to support research on GAN-based text generation models. Since most GAN-based text generation models are implemented in TensorFlow, TextGAN can help those who are used to PyTorch enter the text generation field faster.
If you find any mistake in my implementation, please let me know! Also, please feel free to contribute to this repository if you want to add other models.
![LICENSE](https://img.shields.io/packagist/l/doctrine/orm.svg)
![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)
## Requirements
- **PyTorch >= 1.1.0**
- Python 3.6
- NumPy 1.14.5
- CUDA 7.5+ (For GPU)
- nltk 3.4
- tqdm 4.32.1
- KenLM (https://github.com/kpu/kenlm)
To install, run `pip install -r requirements.txt`. In case of CUDA problems, consult the official PyTorch [Get Started guide](https://pytorch.org/get-started/locally/).
## KenLM Installation
- Download the stable release and unzip it: http://kheafield.com/code/kenlm.tar.gz
- Requires Boost >= 1.42.0 and bjam:
  - Ubuntu: `sudo apt-get install libboost-all-dev`
  - Mac: `brew install boost; brew install bjam`
- Run *within* the kenlm directory:
```bash
mkdir -p build
cd build
cmake ..
make -j 4
```
- Install the Python module: `pip install https://github.com/kpu/kenlm/archive/master.zip`
- For more information on KenLM see: https://github.com/kpu/kenlm and http://kheafield.com/code/kenlm/
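After building KenLM and installing its Python module, a quick sanity check can confirm that `kenlm` is importable from the interpreter you will use to run TextGAN. This sketch uses only the standard library; the helper name `kenlm_available` is hypothetical and not part of TextGAN.

```python
import importlib.util


def kenlm_available() -> bool:
    """Return True if the `kenlm` Python module can be imported (hypothetical helper)."""
    # find_spec locates the module without importing it.
    return importlib.util.find_spec("kenlm") is not None


if __name__ == "__main__":
    if kenlm_available():
        print("kenlm Python module found")
    else:
        print("kenlm not found; re-run the pip install step above")
```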
## Implemented Models and Original Papers
### General Text Generation
- **SeqGAN** - [SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient](https://arxiv.org/abs/1609.05473)
- **LeakGAN** - [Long Text Generation via Adversarial Training with Leaked Information](https://arxiv.org/abs/1709.08624)
- **MaliGAN** - [Maximum-Likelihood Augmented Discrete Generative Adversarial Networks](https://arxiv.org/abs/1702.07983)
- **JSDGAN** - [Adversarial Discrete Sequence Generation without Explicit Neural Networks as Discriminators](http://proceedings.mlr.press/v89/li19g.html)
- **RelGAN** - [RelGAN: Relational Generative Adversarial Networks for Text Generation](https://openreview.net/forum?id=rJedV3R5tm)
- **DPGAN** - [DP-GAN: Diversity-Promoting Generative Adversarial Network for Generating Informative and Diversified Text](https://arxiv.org/abs/1802.01345)
- **DGSAN** - [DGSAN: Discrete Generative Self-Adversarial Network](https://arxiv.org/abs/1908.09127)
- **CoT** - [CoT: Cooperative Training for Generative Modeling of Discrete Data](https://arxiv.org/abs/1804.03782)
### Category Text Generation
- **SentiGAN** - [SentiGAN: Generating Sentimental Texts via Mixture Adversarial Networks](https://www.ijcai.org/proceedings/2018/618)
- **CatGAN** (ours) - [CatGAN: Category-aware Generative Adversarial Networks with Hierarchical Evolutionary Learning for Category Text Generation](https://arxiv.org/abs/1911.06641)
## Get Started
- Clone the repository
```bash
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorch
```
- For real data experiments, all datasets (`Image COCO`, `EMNLP NEWs`, `Movie Review`, `Amazon Review`) can be downloaded from [here](https://drive.google.com/drive/folders/1XvT3GqbK1wh3XhTgqBLWUtH_mLzGnKZP?usp=sharing).
- Run with a specific model
```bash
cd run
python3 run_[model_name].py 0 0 # The first 0 is job_id, the second 0 is gpu_id
# For example
python3 run_seqgan.py 0 0
```
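If you prefer to launch experiments from Python (for instance, to sweep several models), the command above can be built programmatically. This is a sketch that assumes the `run/run_[model_name].py` naming and the two positional arguments (`job_id`, then `gpu_id`) shown above; the helper names `build_run_command` and `launch` are hypothetical.

```python
import subprocess
import sys


def build_run_command(model_name, job_id=0, gpu_id=0):
    """Build the argv for run/run_[model_name].py (hypothetical helper)."""
    script = "run_{}.py".format(model_name.lower())
    # Positional arguments follow the README: job_id first, then gpu_id.
    return [sys.executable, script, str(job_id), str(gpu_id)]


def launch(model_name, job_id=0, gpu_id=0, run_dir="run"):
    """Run one experiment from the `run` directory and wait for it to finish."""
    cmd = build_run_command(model_name, job_id, gpu_id)
    return subprocess.run(cmd, cwd=run_dir, check=True)


if __name__ == "__main__":
    for name in ["seqgan", "relgan"]:
        print(" ".join(build_run_command(name, job_id=0, gpu_id=0)))
```

Calling `launch("seqgan", 0, 0)` from the repository root would be equivalent to the `cd run` and `python3 run_seqgan.py 0 0` steps above.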
## Features
1. **Instructor**
For each model, the entire running process is defined in its own instructor file, e.g. `instructor/oracle_data/seqgan_instructor.py` for SeqGAN in the synthetic data experiment. Some basic functions, such as `init_model()` and `optimize()`, are defined in the base class `BasicInstructor` in `instructor.py`. If you want to add a new GAN-based text generation model, please create a new instructor under `instructor/oracle_data` and define the training process for the model there.
2. **Visualization**
Use `utils/visualization.py` to visualize the log files, including model losses and metric scores. Customize your log files in `log_file_list`, which should contain no more than `len(color_list)` entries. Log filenames should exclude the `.txt` extension.
3. **Logging**
TextGAN-PyTorch uses Python's `logging` module to record the running process, such as the generator's loss and metric scores. For the convenience of visualization, two identical log files are saved, in `log/log_****_****.txt` and `save/**/log.txt` respectively. Furthermore, the code automatically saves the models' state dicts and a batch of generator samples in `./save/**/models` and `./save/**/samples` at each log step, where `**` depends on your hyper-parameters.
4. **Running Signal**
You can easily control the training process with the class `Signal` (see `utils/helpers.py`), which is driven by the dictionary file `run_signal.txt`.
To use the `Signal`, just edit the local file `run_signal.txt`. For example, if you set `pre_sig` to `False`, the program stops the pre-training process and moves on to the next training phase. This makes it convenient to stop training early if you think the current training is sufficient.
5. **Automatically select GPU**
In `config.py`, the program automatically selects the GPU device with the lowest `GPU-Util` reported by `nvidia-smi`. This feature is enabled by default. If you want to select a GPU device manually, uncomment the `--device` argument in `run_[model_name].py` and specify the GPU device on the command line.
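To illustrate the **Instructor** pattern, here is a heavily simplified, framework-free sketch of a new instructor subclass. `BasicInstructor` below is a stand-in, not the repository's class: only `init_model()` and `optimize()` are named in this README and their real signatures may differ, while the `_run()` entry point, epoch counts, and toy losses are assumptions for illustration.

```python
class BasicInstructor:
    """Simplified stand-in for the repo's BasicInstructor (hypothetical)."""

    def __init__(self, pretrain_epochs=3, adv_epochs=2):
        self.pretrain_epochs = pretrain_epochs
        self.adv_epochs = adv_epochs
        self.log = []  # records what happened, instead of real logging

    def init_model(self):
        # The real class sets up generator/discriminator weights here.
        self.log.append("init_model")

    def optimize(self, name, loss):
        # The real class performs an optimizer step; this stand-in only records it.
        self.log.append("optimize {}: {:.3f}".format(name, loss))

    def _run(self):
        raise NotImplementedError("each model defines its own training process")


class MyGANInstructor(BasicInstructor):
    """Hypothetical new model: MLE pre-training followed by adversarial training."""

    def _run(self):
        self.init_model()
        # Phase 1: MLE pre-training of the generator (toy, decreasing loss).
        for epoch in range(self.pretrain_epochs):
            self.optimize("gen_mle", 1.0 / (epoch + 1))
        # Phase 2: alternate generator and discriminator updates.
        for _ in range(self.adv_epochs):
            self.optimize("gen_adv", 0.5)
            self.optimize("dis", 0.7)


if __name__ == "__main__":
    inst = MyGANInstructor()
    inst._run()
    print("\n".join(inst.log))
```

The point of the split is that shared plumbing lives in the base class, while each model only overrides the method that defines its training schedule.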
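The two identical log files described under **Logging** can be produced with the standard `logging` module by attaching two `FileHandler`s to one logger. This is a minimal sketch, not TextGAN's own setup; the function name `create_dual_logger` and the plain `%(message)s` format are assumptions.

```python
import logging
import os


def create_dual_logger(name, log_path, save_path, fmt="%(message)s"):
    """Return a logger that writes every record to both files (hypothetical helper)."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate output via the root logger
    # Drop handlers from a previous call so records are not written twice.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()
    formatter = logging.Formatter(fmt)
    for path in (log_path, save_path):
        # Create parent directories such as log/ or save/<run>/ if needed.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        handler = logging.FileHandler(path, mode="w", encoding="utf-8")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

For example, `create_dual_logger("textgan", "log/log_demo.txt", "save/demo/log.txt")` would mirror every `info()` call into both paths.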
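The **Running Signal** mechanism can be approximated in a few lines: re-read a small dictionary file at every epoch and leave the current phase once its flag turns false. This sketch assumes `run_signal.txt` holds a Python dict literal such as `{"pre_sig": True}`; the `Signal` class below is a simplified stand-in, not the code in `utils/helpers.py`, and `pretrain` is a hypothetical toy loop.

```python
import ast


class Signal:
    """Simplified, file-driven run signal (stand-in for the repo's Signal)."""

    def __init__(self, signal_file):
        self.signal_file = signal_file
        self.flags = {}
        self.update()

    def update(self):
        # Assumed file format: a Python dict literal, e.g. {"pre_sig": True}
        with open(self.signal_file) as f:
            self.flags = ast.literal_eval(f.read())

    @property
    def pre_sig(self):
        # Missing keys default to True so training continues.
        return bool(self.flags.get("pre_sig", True))


def pretrain(signal, max_epochs):
    """Toy pre-training loop that stops early once pre_sig is set to False."""
    completed = 0
    for _ in range(max_epochs):
        signal.update()  # re-read the file so edits take effect mid-run
        if not signal.pre_sig:
            break
        completed += 1  # stands in for one epoch of MLE training
    return completed
```

Re-reading the file each epoch is what lets an edit made in another terminal take effect without restarting the process.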
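Automatically selecting the least-busy GPU amounts to querying per-device utilization and taking the minimum. The sketch below uses `nvidia-smi`'s CSV query mode (`--query-gpu=index,utilization.gpu --format=csv,noheader,nounits`) and keeps parsing separate from the subprocess call so it can be tried without a GPU. It is an illustration under these assumptions, not the logic in `config.py`; the function names are hypothetical.

```python
import subprocess


def parse_gpu_utilization(csv_text):
    """Parse lines like '0, 87' into (gpu_index, utilization_percent) pairs."""
    pairs = []
    for line in csv_text.strip().splitlines():
        index, util = (field.strip() for field in line.split(","))
        pairs.append((int(index), int(util)))
    return pairs


def pick_least_busy_gpu(csv_text):
    """Return the index of the GPU with the lowest utilization (ties: first listed)."""
    pairs = parse_gpu_utilization(csv_text)
    if not pairs:
        raise ValueError("no GPUs reported")
    return min(pairs, key=lambda pair: pair[1])[0]


def query_nvidia_smi():
    """Query index and utilization of every GPU (requires an NVIDIA driver)."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=index,utilization.gpu",
         "--format=csv,noheader,nounits"],
        stdout=subprocess.PIPE, universal_newlines=True, check=True)
    return result.stdout
```

On a machine with NVIDIA drivers, `pick_least_busy_gpu(query_nvidia_smi())` returns the chosen device index.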
## Implementation Details
### SeqGAN
- run file: [run_seqgan.py](run/run_seqgan.py)
- Instructors: [oracle_data](instructor/oracle_data/seqgan_instructor.py), [real_data](instructor/real_data/seqgan_instructor.py)
- Models: [generator](models/SeqGAN_G.py), [discriminator](models/SeqGAN_D.py)
- Structure (from [SeqGAN](https://arxiv.org/pdf/1609.05473.pdf))
![model_seqgan](./assets/model_seqgan.png)
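SeqGAN treats the generator as a policy and scores partial sequences with Monte Carlo roll-outs: a prefix of length t < T is completed N times by a roll-out policy and the discriminator's scores are averaged, while the full sequence is scored directly. The plain-Python sketch below shows that reward and the resulting REINFORCE-style loss; `complete` and `discriminator` are toy stand-ins for the roll-out generator and the discriminator, not the repository's `utils/rollout.py`.

```python
import math
import random


def seqgan_rewards(tokens, complete, discriminator, num_rollouts=16):
    """Monte Carlo estimate of Q(Y_{1:t-1}, y_t) for t = 1..T."""
    T = len(tokens)
    rewards = []
    for t in range(1, T + 1):
        if t < T:
            # Complete the prefix num_rollouts times and average D's scores.
            scores = [discriminator(complete(tokens[:t], T)) for _ in range(num_rollouts)]
            rewards.append(sum(scores) / num_rollouts)
        else:
            # The full sequence is scored by the discriminator directly.
            rewards.append(discriminator(tokens))
    return rewards


def policy_gradient_loss(log_probs, rewards):
    """REINFORCE loss: -sum_t log G(y_t | Y_{1:t-1}) * Q_t (to be minimized)."""
    return -sum(lp * q for lp, q in zip(log_probs, rewards))


if __name__ == "__main__":
    rng = random.Random(0)

    def complete(prefix, total_len):
        # Toy roll-out policy: fill the rest uniformly from a binary vocabulary.
        return prefix + [rng.choice([0, 1]) for _ in range(total_len - len(prefix))]

    def discriminator(seq):
        # Toy discriminator: fraction of 1s, in [0, 1].
        return sum(seq) / len(seq)

    sample = [1, 0, 1, 1]
    q = seqgan_rewards(sample, complete, discriminator)
    log_probs = [math.log(0.5)] * len(sample)
    print(q, policy_gradient_loss(log_probs, q))
```

In a PyTorch implementation the log-probabilities would be tensors, so minimizing this loss pushes up the probability of tokens that lead to high-reward completions.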
### LeakGAN
- run file: [run_leakgan.py](run/run_leakgan.py)
- Instructors: [oracle_data](instructor/oracle_data/leakgan_instructor.py), [real_data](instructor/real_data/leakgan_instructor.py)
- Models: [generator](models/LeakGAN_G.py), [discriminator](models/LeakGAN_D.py)
- Structure (from [LeakGAN](https://arxiv.org/pdf/1709.08624.pdf))
![model_leakgan](assets/model_leakgan.png)
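In LeakGAN, the Worker follows goals set by the Manager, in the FeUdal-network style the paper adopts, and is rewarded for moving the leaked feature vector f in the goal direction g. The intrinsic reward averages cosine similarities over the last c steps: `r_t = (1/c) * sum_{i=1..c} cos(f_t - f_{t-i}, g_{t-i})`. The plain-Python sketch below computes this for toy vectors; it reflects our reading of the paper, not the code in `LeakGAN_G.py`.

```python
import math


def cosine_similarity(u, v, eps=1e-8):
    """Cosine similarity of two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / max(norm, eps)


def intrinsic_reward(features, goals, t, c):
    """Average cosine similarity between f_t - f_{t-i} and g_{t-i}, i = 1..c."""
    if t < c:
        raise ValueError("need at least c previous steps")
    total = 0.0
    for i in range(1, c + 1):
        delta = [a - b for a, b in zip(features[t], features[t - i])]
        total += cosine_similarity(delta, goals[t - i])
    return total / c
```

A reward near 1 means the features moved in the direction the Manager asked for; near -1 means the opposite direction.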
### MaliGAN
- run file: [run_maligan.py](run/run_maligan.py)
- Instructors: [oracle_data](instructor/oracle_data/maligan_instructor.py), [real_data](instructor/real_data/maligan_instructor.py)
- Models: [generator](models/MaliGAN_G.py), [discriminator](models/MaliGAN_D.py)
- Structure (from my understanding)
![model_maligan](assets/model_maligan.png)
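MaliGAN replaces the raw discriminator score with a normalized importance weight. With `r_D(y) = D(y) / (1 - D(y))`, each sampled sequence y_i in a batch gets weight `r_D(y_i) / sum_j r_D(y_j)`, optionally minus a baseline b for variance reduction, and the generator minimizes `-sum_i w_i * log p(y_i)`. The plain-Python sketch below computes these quantities; it reflects our reading of the paper, not the code in `maligan_instructor.py`.

```python
import math


def maligan_weights(d_scores, baseline=0.0, eps=1e-8):
    """Normalized weights r_D(y_i) / sum_j r_D(y_j) - b, with r_D = D / (1 - D)."""
    ratios = [d / max(1.0 - d, eps) for d in d_scores]
    total = max(sum(ratios), eps)  # guard against an all-zero batch
    return [r / total - baseline for r in ratios]


def maligan_generator_loss(log_probs, weights):
    """Weighted negative log-likelihood of the sampled sequences."""
    return -sum(w * lp for w, lp in zip(weights, log_probs))
```

Normalizing over the batch keeps the weights on a fixed scale regardless of how confident the discriminator is overall, which is the main reason for the design.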
### JSDGAN
- run file: [run_jsdgan.py](run/run_jsdgan.py)
- Instructors: [oracle_data](instructor/oracle_data/jsdgan_instructor.py), [real_data](instructor/real_data/jsdgan_instructor.py)
- Models: [generator](models/JSDGAN_G.py) (No discriminator)
- Structure (from my understanding)
![model_jsdgan](assets/model_jsdgan.png)
### RelGAN
- run file: [run_relgan.py](run/run_relgan.py)
- Instructors: [oracle_data](instructor/oracle_data/relgan_instructor.py), [real_data](instructor/real_data/relgan_instructor.py)
- Models: [generator](models/RelGAN_G.py), [discriminator](models/RelGAN_D.py)
- Structure (from my understanding)
![model_relgan](assets/model_relgan.png)
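RelGAN makes sampling differentiable with the Gumbel-Softmax relaxation: Gumbel noise is added to the generator's logits, and a temperature-scaled softmax yields a nearly one-hot vector that the discriminator can consume; the paper raises the inverse temperature during adversarial training. The plain-Python sketch below draws one relaxed sample; the function names and temperature values are assumptions, and a PyTorch implementation would operate on tensors instead of lists.

```python
import math
import random


def sample_gumbel(rng, eps=1e-10):
    """Draw g = -log(-log(u)) with u ~ Uniform(0, 1)."""
    u = min(max(rng.random(), eps), 1.0 - eps)
    return -math.log(-math.log(u))


def gumbel_softmax(logits, temperature, rng):
    """Relaxed one-hot sample: softmax((logits + g) / temperature)."""
    scores = [(logit + sample_gumbel(rng)) / temperature for logit in logits]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)
```

With the same noise, lowering the temperature never changes which token wins; it only makes the output vector closer to one-hot, which is why annealing the temperature trades gradient smoothness for sample fidelity.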
### DPGAN
- run file: [run_dpgan.py](run/run_dpgan.py)
- Instructors: [oracle_data](instructor/oracle_data/dpgan_instructor.py), [real_data](instructor/real_data/dpgan_instructor.py)
- Models: [generator](models/DPGAN_G.py), [discriminator](models/DPGAN_D.py)
- Structure (from [DPGAN](https://arxiv.org/abs/1802.01345))
![model_dpgan](assets/model_dpgan.png)
### DGSAN
- run file: [run_dgsan.py](run/run_dgsan.py)
- Instructors: [oracle_data](instructor/oracle_data/dgsan_instructor.py), [real_data](instructor/real_data/dgsan_instructor.py)
- M