from tensorflow.keras.layers import Dense,BatchNormalization,LeakyReLU,Conv2DTranspose,Reshape,Conv2D,Dropout,Flatten
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
# Dataset 中的 buffer
buffer_size = 60000
# 批次大小
batch_size = 256
# 训练周期
epochs = 15 # 51
# 100 维的随机噪声
noise_dim = 100
# 载入 MNNIST 数据,只需要训练集的图片就可以
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
# reshape 为 4 维数据
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32')
# 将图片归一化到 [0, 1] 区间内
train_images = train_images/ 255.0
# 定义 Dataset,用于生成打乱后的批次数据
# tf.data.Dataset.from_tensor_slices(train_images)
# .shuffle(buffer_size)
# .batch(batch_size)
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(buffer_size).batch(batch_size)
# 定义生成器
def generator_model():
# 顺序模型
model = tf.keras.Sequential()
# 传入噪声数据,然后与 7*7*256 个神经元进行全连接
# 7*7*256 主要是为了后面可以 Reshape 变成(7, 7, 256)
model.add(Dense(7 * 7 * 256, input_shape=(noise_dim,)))
model.add(BatchNormalization())
model.add(LeakyReLU())
# 变成 4 维图像数据(-1,7,7,256)
model.add(Reshape((7, 7, 256)))
# 转置卷积,图像 shape 变成(-1,7,7,128)
model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU())
# 转置卷积,图像 shape 变成(-1,14,14,64)
model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same'))
model.add(BatchNormalization())
model.add(LeakyReLU())
# 转置卷积,图像 shape 变成(-1,28,28,1)
# 激活函数使用 sigmoid,主要是因为我们把 MNIST 数据图片归一化为[0,1]之间了,生成的假图片要跟真实图片数据匹配
model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', activation='sigmoid'))
return model
# 定义判别器
def discriminator_model():
# 顺序模型
model = tf.keras.Sequential()
# 传入一张图片数据进行卷积,卷积后图像 shape 为(1,14,14,64)
model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
model.add(LeakyReLU())
model.add(Dropout(0.3))
# 卷积后图像 shape 为(1,7,7,128)
model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(LeakyReLU())
model.add(Dropout(0.3))
model.add(Flatten())
# 最后输出一个值,激活函数为 sigmoid 函数,用于判断图片的真假
model.add(Dense(1, activation='sigmoid'))
return model
# 创建生成器模型
generator = generator_model()
# 创建判别器模型
discriminator = discriminator_model()
# 生成随机数
noise = tf.random.normal([1, noise_dim])
# print(noise.shape) # (1, 100)
# 传入生成器生成一张图片
# (1, 100)
generated_image = generator(noise, training=False)
# 显示出图片,刚开始模型还没有训练,所以生成的图片会得到噪声图片
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
plt.show()
# 定义 2 分类交叉熵代价函数
cross_entropy = tf.keras.losses.BinaryCrossentropy()
# 判别器 loss,传入对真实图片的判断结果以及对假图片的判断结果
def discriminator_loss(real_output, fake_output):
# tf.ones_like(real_output)表示对真实图片的判断结果应该全为 1
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
# tf.zeros_like(fake_output)表示对假图片的判断结果应该全为 0
fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
# 求总 loss,再返回
total_loss = real_loss + fake_loss
return total_loss
# 生成器 loss,传入判别器对假图片的判断结果
# 对于生成器来说,生成器希望判别器对假图片的判断结果都是 1
# 所以标签设定为 tf.ones_like(fake_output),全为 1
# 生成器模型在训练过程中会不断优化自身参数,使得模型生成逼真的假图片
def generator_loss(fake_output):
return cross_entropy(tf.ones_like(fake_output), fake_output)
# 由于我们需要分别训练两个网络,判别器和生成器的优化器是不同的。
generator_optimizer = tf.keras.optimizers.Adam(3e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# 把生成器模型和判别器模型以及对应的优化器存入 checkpoint
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
/*
开发不易,整理也不易,如需要详细的说明文档和程序,以及完整的数据集,训练好的模型,或者进一步开发,
可加作者新联系方式咨询,WX:Q3101759565,QQ:3101759565
*/
# 生成图片并保存显示
def generate_and_save_images(model, epoch, test_input):
predictions = model(test_input, training=False)
# 画 16 张子图
for i in range(16):
plt.subplot(4, 4, i + 1)
plt.imshow(predictions[i, :, :, 0], cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
# 显示图片
plt.show()
# 训练模型
def train(dataset, epochs):
for epoch in range(epochs):
for image_batch in dataset:
train_step(image_batch)
# 显示和保存图片
# generator是在训练的网络
# 这里可见每一个epoch进行生成图片显示
generate_and_save_images(generator, epoch, seed)
# 每 5 个 epoch 保存一次模型
if epoch % 5 == 0:
# checkpoint 为需要保存的内容
# 'checkpoint_dir'为模型保存位置
# max_to_keep 设置最多保留几个模型
# 保存模型
# checkpoint_number 设置模型编号
manager.save(checkpoint_number=epoch)
# 模型训练
train(train_dataset, epochs)