'''
Created on 2018-06-14
@author: Administrator
'''
import residual_unit.ResUnit_block as ResUnit_block
import residual_unit.ResUnit_bottleneck as ResUnit_bottleneck
import tensorflow as tf
import numpy as np
import random
#Read a CIFAR-10 batch file: returns a dict with keys b'data' and b'labels'
def unpickle(file):
    """Deserialize one CIFAR-10 batch file.

    Returns the pickled dict (keys are bytes, e.g. b'data' and b'labels').
    """
    import pickle
    with open(file, 'rb') as batch_file:
        # CIFAR-10 batches were pickled under Python 2, hence encoding='bytes'.
        return pickle.load(batch_file, encoding='bytes')
def get_cifar_10_data(filepath):
    """Load one CIFAR-10 batch and return (x_data, y_data).

    x_data: float32 array of flattened image rows (one image per row).
    y_data: float32 one-hot label matrix of shape (N, 10).
    """
    cifar_10_dict = unpickle(filepath)
    x_data = np.asarray(cifar_10_dict[b'data'], dtype=np.float32)
    labels = np.asarray(cifar_10_dict[b'labels'], dtype=np.int32)
    # Vectorized one-hot encoding: row i of np.eye(10) is the one-hot
    # vector for class i (replaces the original per-row Python loop).
    y_data = np.eye(10, dtype=np.float32)[labels]
    return x_data, y_data
def shuffle_data(train_size=50000):
    """Load all six CIFAR-10 batches, shuffle them, and split train/test.

    Args:
        train_size: number of shuffled rows assigned to the training split
            (default 50000, matching the original hard-coded split; the
            remaining rows become the test split).

    Returns:
        (train_x, train_y, test_x, test_y) as 2-D float32 arrays.
    """
    parts = [get_cifar_10_data('data/data_batch_%d' % i) for i in range(1, 6)]
    parts.append(get_cifar_10_data('data/test_batch'))
    x_data = np.concatenate([p[0] for p in parts])
    y_data = np.concatenate([p[1] for p in parts])
    x_data = x_data.reshape([x_data.shape[0], -1])
    y_data = y_data.reshape([y_data.shape[0], -1])
    # Shuffle rows; np.random.permutation replaces the original
    # random.sample(range(n), n) idiom with the standard NumPy one.
    shuffle_index = np.random.permutation(x_data.shape[0])
    x_data = x_data[shuffle_index, :]
    y_data = y_data[shuffle_index, :]
    train_x = x_data[:train_size, :]
    train_y = y_data[:train_size, :]
    test_x = x_data[train_size:, :]
    test_y = y_data[train_size:, :]
    return train_x, train_y, test_x, test_y
def next_batch(x_data, y_data, position, batch_size):
    """Return the next mini-batch and the updated cursor position.

    Slices ``batch_size`` rows starting at ``position`` (clamped to 0).
    The final batch may be shorter than ``batch_size``; whenever the end
    of the data is reached the returned position wraps back to 0.
    """
    features = np.array(x_data)
    features = features.reshape(features.shape[0], -1)
    labels = np.array(y_data)
    labels = labels.reshape(labels.shape[0], -1)
    start = max(position, 0)
    end = start + batch_size
    if end >= features.shape[0]:
        # Last (possibly partial) batch: take everything left and wrap.
        return features[start:, :], labels[start:, :], 0
    return features[start:end, :], labels[start:end, :], end
#weight initialization function
def weight_variable(shape, stddev=0.1):
    """Create a weight Variable drawn from a truncated normal.

    truncated_normal resamples any draw farther than 2*stddev from the
    mean, so no extreme initial weight values are produced.
    """
    return tf.Variable(tf.truncated_normal(shape, mean=0, stddev=stddev))
#bias initialization function
def bias_variable(shape, stddev=0.1):
    """Create a bias Variable, initialized like the weights (truncated normal)."""
    init_value = tf.truncated_normal(shape, mean=0, stddev=stddev)
    return tf.Variable(init_value)
def neural_network_train(train_x, train_y, test_x, test_y, batch_size, learning_rate, training_number, lamda=0.01):
    """Build the bottleneck-ResNet graph and train it on CIFAR-10.

    Args:
        train_x, train_y: training images (N, 3072) and one-hot labels (N, 10).
        test_x, test_y: held-out split used for the periodic accuracy report.
        batch_size: mini-batch size per training step.
        learning_rate: initial learning rate (exponentially decayed).
        training_number: total number of training steps.
        lamda: L2 weight-decay coefficient, also passed to the residual units.

    Checkpoints are saved to (and restored from)
    data/model_sample_cnn_bn_wd_ResNet; training stops early once the
    reported test accuracy reaches 0.9.
    """
    x = tf.placeholder('float', [None, 32 * 32 * 3], 'rgb_image_data')
    y = tf.placeholder('float', [None, 10], 'rgb_image_label')  # ground-truth one-hot labels
    training_flag = tf.placeholder(tf.bool)  # True while training (batch-norm mode switch)
    x_image = tf.reshape(x, [-1, 32, 32, 3])
    # NOTE(review): ResUnit_bottleneck is imported at file top as a *module*
    # ("import residual_unit.ResUnit_bottleneck as ResUnit_bottleneck");
    # calling it directly only works if that name actually resolves to a
    # callable — confirm the intended symbol (e.g. ResUnit_bottleneck.ResUnit_bottleneck).
    net = ResUnit_bottleneck(x_image, out_channels=16, bottleneck_channels=8, bn_training_flag=training_flag, pool_flag=False, batch_norm_flag=True, lamda=lamda)
    net = ResUnit_bottleneck(net, out_channels=32, bottleneck_channels=16, bn_training_flag=training_flag, pool_flag=True, batch_norm_flag=True, lamda=lamda)
    net = ResUnit_bottleneck(net, out_channels=32, bottleneck_channels=16, bn_training_flag=training_flag, pool_flag=False, batch_norm_flag=True, lamda=lamda)
    net = ResUnit_bottleneck(net, out_channels=64, bottleneck_channels=32, bn_training_flag=training_flag, pool_flag=True, batch_norm_flag=True, lamda=lamda)
    net = ResUnit_bottleneck(net, out_channels=64, bottleneck_channels=32, bn_training_flag=training_flag, pool_flag=False, batch_norm_flag=True, lamda=lamda)
    # Flatten: two pooling stages reduce 32x32 to 8x8 with 64 channels
    # (assumes each pool_flag=True unit halves the spatial size — TODO confirm
    # against residual_unit.ResUnit_bottleneck).
    net_flat = tf.reshape(net, [-1, 8 * 8 * 64])
    # Linear classification layer.
    W = weight_variable([8 * 8 * 64, 10], stddev=np.sqrt(2.0 / (10)))
    b = bias_variable([10], stddev=np.sqrt(2.0 / (10)))
    logits = tf.matmul(net_flat, W) + b
    predict = tf.nn.softmax(logits)
    # global_step is only incremented by the optimizer; it must not be trained.
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(learning_rate, global_step=global_step, decay_steps=5000, decay_rate=0.95, staircase=True)
    # BUG FIX: softmax_cross_entropy_with_logits_v2 expects *unscaled* logits;
    # the original passed the softmax output, applying softmax twice.
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits)) + tf.contrib.layers.l2_regularizer(scale=lamda)(W)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y, axis=1), tf.argmax(predict, axis=1)), 'float'))
    optimizer = tf.train.AdamOptimizer(lr)
    # With batch normalization, UPDATE_OPS must run each step so the moving
    # mean/variance (used at test time) are kept up to date.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train = optimizer.minimize(loss, global_step=global_step)
    saver = tf.train.Saver()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        batch_position = 0
        # Resume from the latest checkpoint if one exists.
        checkpoint = tf.train.latest_checkpoint("data/model_sample_cnn_bn_wd_ResNet")
        if checkpoint is not None:
            saver.restore(sess, checkpoint)
        for train_step in range(training_number):
            batch_x_data, batch_y_data, batch_position = next_batch(train_x, train_y, batch_position, batch_size)
            sess.run(train, feed_dict={x: batch_x_data, y: batch_y_data, training_flag: True})
            if train_step % 500 == 0:
                print(train_step, sess.run([loss, accuracy], feed_dict={x: batch_x_data, y: batch_y_data, training_flag: False}), end=",")
                test_accuracy = sess.run(accuracy, feed_dict={x: test_x, y: test_y, training_flag: False})
                print("test_accuracy:", test_accuracy)
                saver.save(sess, 'data/model_sample_cnn_bn_wd_ResNet/model.ckpt')
                if test_accuracy >= 0.9:
                    break
    tf.reset_default_graph()