这个仓库是[字节跳动比赛 Byte Cup 2018,自动生成新闻标题](https://biendata.com/competition/bytecup2018/),水滴队(最终成绩23名)的代码。
有关比赛的总结,欢迎移步博客:https://blog.csdn.net/taoyafan/article/details/84879285
## Requirements
python3.5 或以上
tensorflow 1.12(接近的几个版本应该也可以)
## 程序说明
修改自程序[**RLSeq2Seq**](https://github.com/yaserkl/RLSeq2Seq)
主要改动如下:
(1)更换 python 版本为 python3。
(2)对 policy gradient 部分进行了大量的修改,原程序存在很多错误,如计算 ROUGE 时没有将 decode mask 去掉,前向计算时 greedy 和 sample 没有分开,decode 的输入也是一样的。
(3)训练的同时增加 eval,并保存在验证集效果最好的最后三个模型。
(4)增加对[**pointer-generator**](https://github.com/abisee/pointer-generator)的模型的兼容性,可以直接使用其预训练模型。
(5)对 policy gradient 的修改,将论文[A Deep Reinforced Model for Abstractive Summarization](https://arxiv.org/abs/1705.04304)中的公式(15)改为
$$
L_{rl}=(r(y^{s}) - r(y^{g}))\sum_{t=1}^{n'}{\rm log}p(y_{t}^{g}|y_{1}^{g},...y_{t-1}^{g},x)
$$
即将 sample 得到的结果当做 baseline,根据 greedy 得到的结果对来计算梯度。
## 使用说明
参考程序[**RLSeq2Seq**](https://github.com/yaserkl/RLSeq2Seq) 和 [**pointer-generator**](https://github.com/abisee/pointer-generator),他们介绍的很清楚,只是这里只用在比赛中,生成的标题较短,且数据集来自官方。
数据的预处理参考 cnn-dailymail 中的 [**make_datafiles.ipynb**](https://github.com/taoyafan/cnn-dailymail/tree/master/bytecup)。
### 文件说明
[src](https://github.com/taoyafan/abstractive_summarization/tree/master/src) 中为源代码,其中[run_summarization.py](https://github.com/taoyafan/abstractive_summarization/blob/master/src/run_summarization.py)为主程序。
[results](https://github.com/taoyafan/abstractive_summarization/tree/master/results)中为不同模型的运行命令(参数)。
### 参数说明
基本命令同[**RLSeq2Seq**](https://github.com/yaserkl/RLSeq2Seq),增加参数如下:
| 参数 | 说明 |
| -------------------------- | ------------------------------------------------------------ |
| convert_version_old_to_new | 为 True 时可加载 [**pointer-generator**](https://github.com/abisee/pointer-generator) 提供的预训练模型 |
| eval_data_path | 验证集路径 |
| dropout_keep_p | 在 encoder 和 decoder 的 LSTM 的 cell 中增加 drop out,对 input、output 和 state 使用相同的 keep_p,默认为1,即不使用 drop out |
| rising_greedy_r | 为 True 时 policy gradient 使用更改后的公式,即目标为提升 greedy 得到的 reward,为 False 时为原公式,但是占用显存增大一倍 |
### 运行说明
在 results 中寻找对应模型的命令,如基准模型 [base_eta=0_lr=0.15.txt](https://github.com/taoyafan/abstractive_summarization/blob/master/results/base_eta%3D0_lr%3D0.15.txt)
在训练时执行命令:
```
python3 run_summarization.py --mode=train --data_path=../finished_files/chunked/train* --eval_data_path=../finished_files/chunked/test* --vocab_path=../finished_files/vocab --log_root=../log --exp_name=base_eta=0_lr=0.15 --batch_size=20 --use_temporal_attention=False --intradecoder=False --eta=0 --rl_training=True --lr=0.15 --sampling_probability=0 --fixed_eta=True --scheduled_sampling=True --fixed_sampling_probability=True --greedy_scheduled_sampling=True
```
经过实验 drop out 取 0.8 时效果最好,不过最终没来得及使用,最终成绩所使用的模型为:
基础模型(有pointer_gen,无coverage,无rl),然后使用 policy gradient。
没有合适的资源?快使用搜索试试~ 我知道了~
Byte Cup 2018国际机器学习竞赛 23 名(水滴队)代码.zip
共77个文件
py:19个
pyc:19个
txt:17个
需积分: 5 0 下载量 120 浏览量
2024-05-08
10:09:35
上传
评论
收藏 963KB ZIP 举报
温馨提示
Byte Cup 2018国际机器学习竞赛 23 名(水滴队)代码.zip
资源推荐
资源详情
资源评论
收起资源包目录
Byte Cup 2018国际机器学习竞赛 23 名(水滴队)代码.zip (77个子文件)
content
src
__init__.py 0B
.DS_Store 8KB
replay_buffer.pyc 12KB
helper
newsroom_data_maker.py 7KB
cnn_dm_data_maker.py 4KB
README.rst 4KB
cnn_dm_data_merger.py 5KB
cnn_dm_downloader.py 3KB
filter_files.txt 665KB
decode.py 12KB
rouge_tensor.pyc 7KB
util.py 2KB
attention_decoder.py 39KB
inspect_checkpoint.py 1KB
attention_decoder.pyc 22KB
replay_buffer.py 10KB
model.py 50KB
beam_search.py 10KB
rouge_tensor.py 7KB
dqn.py 6KB
data.pyc 12KB
data.py 13KB
batcher.py 18KB
rouge.py 11KB
file_spliter.py 843B
batcher.pyc 15KB
__pycache__
beam_search.cpython-35.pyc 8KB
util.cpython-35.pyc 2KB
replay_buffer.cpython-35.pyc 10KB
dqn.cpython-35.pyc 6KB
batcher.cpython-35.pyc 15KB
decode.cpython-35.pyc 10KB
attention_decoder.cpython-35.pyc 21KB
rouge_tensor.cpython-35.pyc 6KB
model.cpython-35.pyc 34KB
__init__.cpython-35.pyc 175B
data.cpython-35.pyc 11KB
rouge.cpython-35.pyc 10KB
model.pyc 36KB
rouge.pyc 11KB
run_summarization.py 54KB
filter_files.txt 665KB
.idea
vcs.xml 180B
results
rising_sample_r .txt 1KB
rising_greedy_r.txt 1KB
base_eta=0_lr=0.15.txt 1KB
base_dropout_0 .8 5.txt 1KB
history
lr=0.0001 383B
base_no_temporal_attention 1KB
baseline.txt 1KB
base_eta=0_lr=0.15 1KB
risng_greedy_rouge_lr=0.15 415B
risng_greedy_rouge_lr=1 403B
lr=0.05 375B
noself_critic_lr=1 411B
scripts.txt 6KB
base_p=1_no_temporal.txt 1KB
lr=10 367B
rising_greedy_r_no_temporal_after_pre-train.txt 1KB
base_sample_p=1 1KB
nointradecoder-notemporal-withpretraining-after-RL.txt 1KB
lr=0.15 415B
avg_reward_lr=0.15 388B
lr=1 363B
lr=100 371B
sample_p=1_after_pretrain 1KB
policy_gradient 3KB
sample_p=1_lr=0.15 1KB
AC_DDQN 2KB
code_test 419B
nointradecoder-notemporal-withpretraining-before-RL.txt 1KB
base_dropout_0 .8.txt 1KB
new_rising_greedy_r_lr=1.5.txt 2KB
base_dropout_0 .7.txt 1KB
base_dropout_0 .9.txt 1KB
rising_greedy_r_dropout_0.9.txt 1KB
README.md 4KB
共 77 条
- 1
资源评论
生瓜蛋子
- 粉丝: 3828
- 资源: 5775
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 数据库管理工具:dbeaver-ce-23.1.5-stable.x86-64.rpm
- 以下是一些适用于英语六级作文的万能句型模板,涵盖了引言、正文和结论部分的各类表达方式.docx
- MATLAB中的非线性规划
- 进行C语言面试资格确认是招聘过程中一个重要的步骤,目的是确保候选人具备足够的C语言编程能力和知识.docx
- Java 轻量级的集群负载均衡设计
- 纹身师个人网站模板.jpg
- 在C语言中,连接两个字符串(即将一个字符串附加到另一个字符串的末尾)通常可以使用标准库中的 `strcat` 函数.docx
- 数据库管理工具:dbeaver-ce-23.1.1-stable.x86-64.rpm
- 以下是几个具体竞赛题目的详细解答,包括建模思路、方法和步骤 .docx
- 一份关于全国大学生建模大赛的相关教程!!
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功