clear;
clc;
% ----------- Load the data
% Only the training images are loaded from the mnist_uint8 .mat file.
load('mnist_uint8', 'train_x');
% Convert to double scaled by 1/255, view each row as 28x28, swap the two
% 28-pixel axes (presumably to transpose each image into the orientation the
% rest of the code expects - confirm), then flatten back to 60000 x 784.
train_x = reshape(permute(double(reshape(train_x, 60000, 28, 28)) / 255, [1, 3, 2]), 60000, 784);
% ----------------- Define the models
% Generator layer sizes: 100 (noise) -> 512 -> 784 (pixels).
% Discriminator layer sizes: 784 (pixels) -> 512 -> 1 (logit).
% The construction order is kept as-is in case GAN_nnsetup draws random weights.
generator = GAN_nnsetup([100,512,784]);
discriminator = GAN_nnsetup([784,512,1]);
% ----------- Training hyper-parameters
learning_rate = 0.001;
epoch = 1;              % originally 100
images_num = 60000;     % number of training images in train_x
batch_size = 60;
batch_num = ceil(images_num / batch_size);
% ----------- Training loop
% Folder that receives one sample grid per epoch; created on demand so that
% saving does not fail when it is missing.
sample_dir = 'C:\Users\Administrator\Documents\MATLAB\pictures';
if ~exist(sample_dir, 'dir')
    mkdir(sample_dir);
end
for e = 1:epoch
    % Fresh random ordering of the training images each epoch.
    kk = randperm(images_num);
    for t = 1:batch_num
        % ---- Prepare the mini-batch
        % Clamp the upper index: batch_num uses ceil, so when images_num is
        % not a multiple of batch_size the last batch is shorter.
        batch_idx = kk((t - 1) * batch_size + 1 : min(t * batch_size, images_num));
        cur_batch = numel(batch_idx);
        images_real = train_x(batch_idx, :);
        % Generator input: cur_batch x 100 noise, uniform on [-1, 1].
        noise = unifrnd(-1, 1, cur_batch, 100);

        % ---- Update the generator, discriminator held fixed
        generator = GAN_nnff(generator, noise);
        images_fake = generator.layers{generator.layers_count}.a;
        discriminator = GAN_nnff(discriminator, images_fake);
        logits_fake = discriminator.layers{discriminator.layers_count}.z;
        % Fake images are scored against the "real" label (1) so the generator
        % learns to fool the discriminator. Backprop runs through the
        % discriminator (presumably so nnbp_g can continue into the generator),
        % but only the generator's weights are stepped in this phase.
        discriminator = nnbp_d(discriminator, logits_fake, ones(cur_batch, 1));
        generator = nnbp_g(generator, discriminator);
        generator = GAN_nnapplygrade(generator, learning_rate);

        % ---- Update the discriminator, generator held fixed
        % Regenerate fakes from the same noise with the just-updated generator.
        generator = GAN_nnff(generator, noise);
        images_fake = generator.layers{generator.layers_count}.a;
        images = [images_fake; images_real];
        discriminator = GAN_nnff(discriminator, images);
        logits = discriminator.layers{discriminator.layers_count}.z;
        % Labels follow the row order of images: fakes -> 0, real -> 1.
        labels = [zeros(cur_batch, 1); ones(cur_batch, 1)];
        discriminator = nnbp_d(discriminator, logits, labels);
        discriminator = GAN_nnapplygrade(discriminator, learning_rate);

        % ---- Once per epoch (last batch): report losses and save a sample
        if t == batch_num
            % Generator loss: fake-image logits against the "real" label.
            % Both losses use logits computed before this batch's
            % discriminator update.
            c_loss = sigmoid_cross_entropy(logits(1:cur_batch), ones(cur_batch, 1));
            d_loss = sigmoid_cross_entropy(logits, labels);
            fprintf('c_loss:"%f",d_loss:"%f"\n', c_loss, d_loss);
            % Named save_path (not "path") to avoid shadowing MATLAB's
            % built-in path function.
            save_path = fullfile(sample_dir, ['epoch_', int2str(e), '_t_', int2str(t), '.png']);
            % NOTE(review): images_fake has cur_batch rows but the grid is
            % 4x4; presumably save_images uses only the first 16 - confirm.
            save_images(images_fake, [4, 4], save_path);
            fprintf('save_sample:%s\n', save_path);
        end
    end
end
%%链接:https://www.jianshu.com/p/3f89e95f891d
%%來源:简书
%%简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。