# 人工智能导论期末大作业报告
## AI 对联
1. 背景
微软亚洲研究院:用户输入上联,电脑根据上联给出相应的下联供选择,选择后再给出横批,最后生成对联图片可供保存下载。
- ![](https://www.writebug.com/myres/static/uploads/2021/11/12/b20fe182f6a6b6ddfb830bed2468dea7.writebug)
2. 数据集极其处理:
1. 数据来源:数据来源于一位新浪博主编辑的一部名为《联语杂酱面》的书,训练集包括 70 万副对联,测试集包括 4000 副对联。数据格式为每行一个上联或者夏凉的文本文档。
2. 数据处理过程:主要是文本文档的读取和字典的一些操作。首先是一些库的引入和参数设置,然后开始读取数据和处理。
```python
import codecs
import numpy as np
from keras.models import Model
from keras.layers import *
from keras.callbacks import Callback
min_count = 2
maxlen = 16
batch_size = 64
char_size = 128
train_input_path = 'datasets/train_in.txt'
train_output_path = 'datasets/train_out.txt'
test_input_path = 'datasets/test_in.txt'
test_output_path = 'datasets/test_out.txt'
```
1. 按行读入:按照空格把每个对联分成字的集合。
```python
def read_data(txtname):
txt = codecs.open(txtname, encoding='utf-8').readlines()
txt = [line.strip().split(' ') for line in txt] # 每行按空格切分
txt = [line for line in txt if len(line) <= maxlen] # 过滤掉字数超过maxlen的对联
return txt
x_train_txt = read_data(train_input_path)
y_train_txt = read_data(train_output_path)
x_test_txt = read_data(test_input_path)
y_test_txt = read_data(test_output_path)
```
![](https://www.writebug.com/myres/static/uploads/2021/11/12/20d9658f176a7ae5a398a41b126303be.writebug)
2. 保存到字典:统计每个字在数据集中占的个数。
```python
chars = {}
for txt in [x_train_txt,y_train_txt,x_test_txt,y_test_txt]:
for line in txt:
for word in line:
chars[word] = chars.get(word,0) + 1
c = 0
##查看字典数据
for word,count in chars.items():
if c <=5:
print(word,count)
c = c+1
```
![](https://www.writebug.com/myres/static/uploads/2021/11/12/a8d9928ab667e3bf7e79402ad36b08af.writebug)
3. 映射为数字:设置一个映射函数,把字相应的映射为数字。
```python
chars = {word:count for word,count in chars.items() if count >= min_count}
id2char = {word_id+1:word for word_id,word in enumerate(chars)}
# 更换一下key-value的位置
char2id = {word:word_id for word_id,word in id2char.items()}
def string2id(char_list):
return [char2id.get(char,0) for char in char_list]
x_train = list(map(string2id, x_train_txt))
y_train = list(map(string2id, y_train_txt))
x_test = list(map(string2id, x_test_txt))
y_test = list(map(string2id, y_test_txt))
```
![](https://www.writebug.com/myres/static/uploads/2021/11/12/6b78d350d38d241063539b1d31bf808f.writebug)
4. 整合上下联:用二维数组保存上下联的数字对应,并转换成 ndarray 类型,至此数据处理完成
```python
def generate_count_dict(result_dict,x,y):
for i,charIDlist in enumerate(x):
j = len(charIDlist)
if j not in result_dict:
result_dict[j] = [[],[]] # [存放长度为j的上联的字的匹配编号,对应的下联]
result_dict[j][0].append(charIDlist)
result_dict[j][1].append(y[i])
return result_dict
train_dict = {}
test_dict = {}
train_dict = generate_count_dict(train_dict, x_train, y_train)
test_dict = generate_count_dict(test_dict,x_test,y_test)
print('共有{}种不同的字数'.format(len(train_dict.keys())))
for wordCount,[data,y] in train_dict.items():
print('字数',wordCount,':对应上联x的个数:',len(data),'下联y的个数',len(y))
```
![](https://www.writebug.com/myres/static/uploads/2021/11/12/c8a0e5dd3d01ecedef426a93be257228.writebug)
3. 模型建立和训练
1. 模型介绍:两个 Conv1D 形式一样,卷积核数窗口大小都一样,但是权值不共享,也就是说参数翻倍了,其中一个用 sigmoid 函数激活,另外一个不加激活函数然后将它们逐位相乘。解决了梯度消失并且使信息能够在多通道传输。
![](https://www.writebug.com/myres/static/uploads/2021/11/12/463da1de9c206aada5c0e7f9bd6bbc5a.writebug)
2. 模型搭建:函数 data_generator()用于后面作为模型的参数,用 yield 生成器来传入模型进行训练。门卷积层用 keras 的卷积层和匿名函数来实现,最后模型的搭建设置了一个输入层,一个嵌入层,一个防止过拟合的 Dropout 层,六个门卷积层,最后一个全连接层。
```python
# 构建batch大小的数据集
# 随机抽取生成大小为batch的上联与下联的数据集
# data: train_dict or test_dict
def data_generator(data):
# 计算每个对联长度的权重
data_probability = [float(len(x)) for wordcount,[x,y] in data.items()] #[每个字数key(1-16)对应对联list中上联数据的个数]
print(data_probability)
data_probability = np.array(data_probability) / sum(data_probability) #标准化至[0,1],这是每个字数的权重
# 随机选择字数,然后随机选择字数对应的上联样本,生成batch
while True:
# 随机选字数id,概率为上面计算的字数权重
idx = np.random.choice(len(data_probability), p = data_probability) + 1
size = min(batch_size, len(data[idx][0])) # batch_size=64,len(data[idx][0])随机选择的字数key对应的上联个数
# 从上联列表下标list中随机选出大小为size的list
idxs = np.random.choice(len(data[idx][0]), size = size)
# 返回选出的上联x与下联y, 将原本1-d array维度扩展为(row,col,1)
yield data[idx][0][idxs], np.expand_dims(data[idx][1][idxs],axis=2)
# return data[idx][0][idxs], np.expand_dims(data[idx][1][idxs],axis=2)
```
```python
# 门卷积模块
def gated_resnet(x, ksize=3):
# 门卷积 + 残差
x_dim = K.int_shape(x)[-1]
xo = Conv1D(x_dim*2, ksize, padding='same')(x)
return Lambda(lambda x: x[0] * K.sigmoid(x[1][..., :x_dim]) + x[1][..., x_dim:] * K.sigmoid(-x[1][..., :x_dim]))([x, xo])
x_in = Input(shape=(None,))
x = x_in
x = Embedding(len(chars)+1, char_size)(x)
x = Dropout(0.25)(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = Dense(len(chars)+1, activation='softmax')(x)
```
3. 模型训练:用 Evaluate 来训练模型,模型训练过程中传入三个测试的对联以展示模型随着训练的结果而逐渐优化的效果。
```python
def couplet_match(s):
# 输出对联
# 先验知识:跟上联同一位置的字不能一样
x = np.array([string2id(s)]) # 上联-->id array
y = model.predict(x)[0]
for i,j in enumerate(x[0]):
y[i, j] = 0.
y = y[:, 1:].argmax(axis=1) + 1
r = ''.join([id2char[i] for i in y])
print('上联:%s,下联:%s' % (s, r))
return r
class Evaluate(Callback):
def __init__(self):
self.lowest = 1e10
def on_epoch_end(self, epoch, logs=None):
# 训练过程中观察几个例子,显示对联质量提高的过程
couplet_match(u'晚风摇树树还挺')
couplet_match(u'今天天气不错')
couplet_ma
神仙别闹
- 粉丝: 4311
- 资源: 7532
最新资源
- 童心党史小程序-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.zip
- 速达物流信息查询微信小程序设计与实现ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 小区租拼车管理信息系统+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 无中介租房系统+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校友会系统的实现+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 微信点餐系统-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.zip
- 校友林微信小程序+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校园二手数码交易平台+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 微信点餐系统小程序ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校园工会体育报名系统+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 外卖小程序ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校园顺路代送微信小程序ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校园服务平台+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 基于STM32开发的数字频率计项目 组成部分:时基电路,整形电路,调节电路,信号输入 实现功能:测量信号输入幅度1-5v方波,频率为1khz-10khz测量精度1%,信号输出 当输入信号大于15v
- 校园约拍微信小程序设计与实现ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
- 校园资讯平台微信小程序+ssm-微信小程序毕业项目,适合计算机毕-设、实训项目、大作业学习.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈