from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import utils
# Training hyper-parameters.
batch_size = 256   # examples per training step
noise_scale = 0.5  # multiplier applied to the standard-normal noise added to inputs

# MNIST with one-hot labels (the labels themselves are not used by this autoencoder).
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
sess = tf.InteractiveSession()
def weights_variable(shape):
    """Return a new trainable weight tensor of the given shape.

    Initial values come from tf.truncated_normal with standard deviation 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape):
    """Return a new trainable bias tensor of the given shape, filled with 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))
# strides: one entry per input axis - batch, height, width, channel.
def conv2d(x, W):
    """2-D convolution of `x` with filter `W`, stride 1 and SAME padding.

    SAME padding with unit stride keeps the spatial size of `x` unchanged.
    """
    unit_step = [1, 1, 1, 1]  # batch, height, width, channel
    return tf.nn.conv2d(x, W, strides=unit_step, padding='SAME')
def deconv2d(x, W, output):
    """Transposed 2-D convolution with stride 2 in height/width and SAME padding.

    Args:
        x: input tensor.
        W: filter in tf.nn.conv2d_transpose layout
           [height, width, out_channels, in_channels].
        output: shape of the result, forwarded as the `output_shape` argument.
    """
    stride_two = [1, 2, 2, 1]  # batch, height, width, channel
    return tf.nn.conv2d_transpose(x, W, output, strides=stride_two, padding='SAME')
# ksize: pooling window size per input axis - batch, height, width, channel.
def max_pool_2x2(x):
    """Non-overlapping 2x2 max pooling with SAME padding.

    Halves height and width; with SAME padding an odd size rounds up (7 -> 4).
    """
    window = [1, 2, 2, 1]  # batch, height, width, channel
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
# Clean images (the reconstruction target), flattened to 784 = 28*28 values per row.
x = tf.placeholder(tf.float32, [None, 784], name='input')
# Noisy images fed to the encoder, same layout as `x`.
# NOTE(review): this reuses the name 'input', so TF uniquifies it to 'input_1';
# a distinct name would be clearer, but code restoring the graph may look it up as is.
x_noise = tf.placeholder(tf.float32, [None, 784], name='input')
#y_ = tf.placeholder(tf.float32, [None, 10])
# -1 lets the number of examples vary; the trailing 1 is the number of colour channels.
x_image = tf.reshape(x, [-1, 28, 28, 1])
x_image_noise = tf.reshape(x_noise, [-1, 28, 28, 1])
# Encoder: three stages of 3x3 conv + ReLU + 2x2 max pool applied to the noisy
# image; spatial size 28x28 -> 14x14 -> 7x7 -> 4x4 (SAME pooling rounds 7 up to 4).
# Statement order matters: the unnamed tf.Variables are auto-named in creation order.
W_conv1 = weights_variable([3, 3, 1, 64])
b_conv1 = bias_variable([64])
h_conv1 = tf.nn.relu(conv2d(x_image_noise, W_conv1) + b_conv1, name='noise_layer')
h_pool1 = max_pool_2x2(h_conv1)  # 14x14x64
W_conv2 = weights_variable([3, 3, 64, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)  # 7x7x64
W_conv3 = weights_variable([3, 3, 64, 32])
b_conv3 = bias_variable([32])
h_conv3 = tf.nn.relu(conv2d(h_pool2, W_conv3) + b_conv3)
h_pool3 = max_pool_2x2(h_conv3)  # 4x4x32
# Decoder: three stride-2 transposed convolutions, 4x4 -> 7x7 -> 14x14 -> 28x28.
# NOTE(review): these layers have no bias and no activation - confirm intended.
# The batch entry of every output shape is taken from the runtime input rather
# than the fixed `batch_size`; with a constant, feeding any other number of
# examples (e.g. the last partial batch of the test set) failed. Feeding exactly
# `batch_size` rows behaves as before, and the weight variables are created in
# the same order, so their auto-generated names (checkpoint keys) are unchanged.
decoder_batch = tf.shape(h_pool3)[0]
W_deconv1 = weights_variable([3, 3, 32, 32])
h_deconv1 = deconv2d(h_pool3, W_deconv1, tf.stack([decoder_batch, 7, 7, 32]))
W_deconv2 = weights_variable([3, 3, 64, 32])
h_deconv2 = deconv2d(h_deconv1, W_deconv2, tf.stack([decoder_batch, 14, 14, 64]))
W_deconv3 = weights_variable([3, 3, 64, 64])
h_deconv3 = deconv2d(h_deconv2, W_deconv3, tf.stack([decoder_batch, 28, 28, 64]))
# Output convolution: 3x3, 64 -> 1 channel, bias only (no activation).
W_conv_final = weights_variable([3, 3, 64, 1])
b_conv_final = bias_variable([1])
h_conv_final = tf.nn.bias_add(conv2d(h_deconv3, W_conv_final), b_conv_final, name='output_layer')
# Reconstructed image, used for the TensorBoard image summary.
output_img = tf.reshape(h_conv_final, shape=[-1, 28, 28, 1])
# NOTE(review): utils.convert2int is defined outside this file; presumably it
# converts the image to an integer dtype for display - confirm.
output_img_fomat = utils.convert2int(output_img)
# Flattened reconstruction and clean target, compared by the loss.
output = tf.reshape(h_conv_final, shape=[-1, 784])
# NOTE(review): `input` shadows the Python builtin of the same name.
input = tf.reshape(x, shape=[-1, 784])
# Training objective: mean squared error between reconstruction and clean image.
residual = tf.subtract(output, input)
cost = tf.reduce_mean(tf.pow(residual, 2.0))
optimizer = tf.train.AdamOptimizer(0.001)  # learning rate 0.001
train_steps = optimizer.minimize(cost)
# TensorBoard image summaries: clean input, noisy input and reconstruction
# (one example each).
with tf.name_scope('images'):
    tf.summary.image('input', x_image, 1)
    tf.summary.image('gaussian', x_image_noise, 1)
    tf.summary.image('reconstruction', (output_img_fomat), 1)
merged = tf.summary.merge_all()

n_samples = int(mnist.train.num_examples)
print('train samples: %d' % n_samples)
print('batch size: %d' % batch_size)
# Number of whole batches per epoch (integer division drops the remainder).
total_batch = n_samples // batch_size
print('total batchs: %d' % total_batch)

init = tf.global_variables_initializer()
sess.run(init)

train_epochs = 35
log_dir = 'logs/mnist_with_summaries'
# NOTE(review): despite the '/test' suffix, this writer records training-step summaries.
test_writer = tf.summary.FileWriter(log_dir + '/test')
saver = tf.train.Saver()
checkpoint_dir = './checkpoint_dir'
try:
    for epoch in range(train_epochs):
        for batch_index in range(total_batch):
            batch_xs, _ = mnist.train.next_batch(batch_size)  # labels unused
            # Corrupt the batch with scaled Gaussian noise, then clip into [0, 1].
            noise_x = batch_xs + noise_scale * np.random.randn(*batch_xs.shape)
            noise_x = np.clip(noise_x, 0., 1.)
            _, train_loss, summaries = sess.run(
                [train_steps, cost, merged],
                feed_dict={x: batch_xs, x_noise: noise_x})
            print("epoch: %04d\tbatch: %04d\ttrain loss: %.9f" % (epoch + 1, batch_index + 1, train_loss))
            test_writer.add_summary(summaries, epoch * total_batch + batch_index)
    # tf.train.Saver.save raises ValueError when the parent directory of the
    # save path does not exist, so create it first (no-op if it already exists).
    tf.gfile.MakeDirs(checkpoint_dir)
    saver.save(sess, checkpoint_dir + '/CAEmodel')
finally:
    # Flush and close the event file even if training is interrupted.
    test_writer.close()

n_test_samples = int(mnist.test.num_examples)
test_total_batch = n_test_samples // batch_size
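# (Scraped web-page residue "评论1" / "最新资源" - "1 comment" / "latest resources" - stood here; it is not part of the script.)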