#基于Kera环境实现的
from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import sys
import numpy as np
class GAN():
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
# 生成原始噪点数据大小
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
# Build and compile the discriminator
# 1、建立判别器训练参数
# 选择损失,优化器,以及衡量准确率
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
# Build the generator
# 2、联合建立生成器训练参数,指定生成器损失
self.generator = self.build_generator()
# The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
# For the combined model we will only train the generator
# 合并模型的损失,并且之后只训练生成器,判别器不训练
self.discriminator.trainable = False
# The discriminator takes generated images as input and determines validity
validity = self.discriminator(img)
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
# 训练生成器欺骗判别器
self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
def build_generator(self):
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()
noise = Input(shape=(self.latent_dim,))
img = model(noise)
return Model(noise, img)
def build_discriminator(self):
model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
img = Input(shape=self.img_shape)
validity = model(img)
return Model(img, validity)
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
# Adversarial ground truths
# 正负样本的目标值建立
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
for epoch in range(epochs):
# ---------------------
# Train Discriminator
# ---------------------
# Select a random batch of images
# 1、训练判别器
# 选择随机的一些真实样本
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# 生成器产生假样本
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# Generate a batch of new images
gen_imgs = self.generator.predict(noise)
# Train the discriminator
# 训练判别器过程
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
# 计算平均两部分损失
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# ---------------------
# Train Generator
# ---------------------
# 2、训练生成器,停止判别器
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
# 合并训练,并停止训练判别器
# 用目标值为1去训练,目的使得生成器生成的样本越来越接近真是样本
# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
# Plot the progress
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# If at save interval => save generated image samples
if epoch % sample_interval == 0:
self.sample_images(epoch)
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
# Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()
if __name__ == '__main__':
gan = GAN()
gan.train(epochs=30000, batch_size=32, sample_interval=200)
没有合适的资源?快使用搜索试试~ 我知道了~
GAN对抗网络(原始版).zip
共153个文件
png:150个
gitignore:2个
py:1个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
0 下载量 198 浏览量
2023-08-19
19:31:37
上传
评论
收藏 13.47MB ZIP 举报
温馨提示
GAN对抗神经网络的实现
资源推荐
资源详情
资源评论
收起资源包目录
GAN对抗网络(原始版).zip (153个子文件)
.gitignore 13B
.gitignore 13B
200.png 213KB
0.png 175KB
400.png 159KB
600.png 155KB
800.png 155KB
1000.png 154KB
1200.png 152KB
1600.png 145KB
1400.png 144KB
1800.png 142KB
2200.png 140KB
2000.png 139KB
3200.png 133KB
2600.png 130KB
2400.png 128KB
3600.png 126KB
3800.png 125KB
2800.png 124KB
3400.png 124KB
3000.png 123KB
4200.png 123KB
4000.png 121KB
4800.png 119KB
4400.png 118KB
4600.png 116KB
5000.png 116KB
5200.png 113KB
5600.png 112KB
9400.png 109KB
5400.png 108KB
5800.png 108KB
6200.png 106KB
7800.png 106KB
10600.png 106KB
6400.png 105KB
7000.png 105KB
8000.png 105KB
7600.png 105KB
8200.png 105KB
9200.png 104KB
9800.png 104KB
9600.png 104KB
8800.png 103KB
6800.png 103KB
6000.png 103KB
7400.png 103KB
8400.png 101KB
7200.png 101KB
6600.png 101KB
11000.png 100KB
10400.png 100KB
9000.png 100KB
13600.png 99KB
11800.png 98KB
11400.png 98KB
14200.png 98KB
15000.png 97KB
13800.png 97KB
20800.png 96KB
8600.png 96KB
15800.png 96KB
10000.png 96KB
24400.png 96KB
17000.png 95KB
12200.png 94KB
22000.png 94KB
16800.png 94KB
24600.png 94KB
14600.png 94KB
11200.png 94KB
15600.png 93KB
21800.png 93KB
13000.png 93KB
18400.png 92KB
12600.png 92KB
16600.png 92KB
12800.png 92KB
25000.png 92KB
25200.png 91KB
19600.png 91KB
21200.png 91KB
12000.png 91KB
25800.png 90KB
28800.png 90KB
15200.png 90KB
24200.png 90KB
20000.png 90KB
10800.png 89KB
28400.png 89KB
25400.png 89KB
22400.png 89KB
18000.png 89KB
17200.png 89KB
16200.png 89KB
17600.png 89KB
20400.png 88KB
13200.png 88KB
26000.png 88KB
共 153 条
- 1
- 2
资源评论
sjx_alo
- 粉丝: 1w+
- 资源: 1208
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功