import tensorflow as tf
import get_data
import model
import time
#FLAGS的好处在于他定义的是全局变量,可以共享
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_data_dir', './data/train',"""存放训练集的路径""")
tf.app.flags.DEFINE_string('train_log_dir', './log/train', """保存训练参数的路径""")
tf.app.flags.DEFINE_integer('max_steps', 6001,"""迭代次数""")
tf.app.flags.DEFINE_integer('log_frequency', 100,"""迭代多少次打印一次训练信息""")
#初始学习率,设置大了会训练失败,建议多试几次,调到一个最大,又不会训练失败的,这样训练才快
INITIAL_LEARNING_RATE = 0.01
#迭代多少步调整一次学习率,这个值需要根据经验设置,一般看损失的变化,如果发现损失一直在变小,没有波动的痕迹,那就设置的大一点
DECAY_STEP=700
#这个变量用于 tf.train.ExponentialMovingAverage,具体效果说不太清楚,可以查资料
MOVING_AVERAGE_DECAY = 0.9999
#学习率每次衰减多少,这个也需要经验设置
LEARNING_RATE_DECAY_FACTOR = 0.8
def loss(logits, labels):
labels = tf.cast(labels, tf.int64)#类型转换函数
#softmax计算损失,这里输出的是个矩阵
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits, name='cross_entropy_per_example')
#取均值,这里求出的才是一个损失值
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
#将变量加入到集合里,方便管理
tf.add_to_collection('losses', cross_entropy_mean)
#获取所有losses的变量
losses = tf.get_collection('losses')
#所有losses加起来,才是总损失
total_loss=tf.add_n(losses, name='total_loss')
#滑步平均法,是为了让参数不被更新得太快,使模型训练更稳定
loss_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, name='avg')
loss_averages_op = loss_averages.apply(losses + [total_loss])#这里就是把需要更新得变量都添加进去
#这个是为了记录数据,方便分析训练想过的
tf.summary.scalar(total_loss.op.name + ' (raw)', total_loss)#保存总损失到本地
tf.summary.scalar(total_loss.op.name, loss_averages.average(total_loss))#保存影子变量到本地
return total_loss,loss_averages_op
def gradient_descent(total_loss,loss_averages_op,global_step):
#计算当前的学习率
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,global_step,DECAY_STEP,
LEARNING_RATE_DECAY_FACTOR,staircase=True)
tf.summary.scalar('learning_rate', lr)#保存数据
#这里的作用是在其中的操作必须要等括号中的操作[loss_averages_op]搞定先
with tf.control_dependencies([loss_averages_op]):
opt = tf.train.GradientDescentOptimizer(lr)#生成一个梯度下降优化器对象
grads = opt.compute_gradients(total_loss)#计算总损失的梯度
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)#按梯度下降计算变量下降后的值
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
train_op = tf.no_op(name='train') #这里不做任何操作,只是为了控制变量以上变量都能够运行,因为运行这一步之前必须把上面括号里的变量都执行完
return train_op
def train ():
with tf.Graph().as_default():
#创建一个全局步调(global_step),与图有关,所以需要放在图下面执行
global_step = tf.train.get_or_create_global_step()
#第一步获取数据
#强制该操用CPU执行,避免有时操作在GPU上结束和导致减速(不懂什么意思,大概是不占用GPU的计算资源)
with tf.device('/cpu:0'):
images, labels = get_data.getData(FLAGS.train_data_dir,True)#获取训练时的数据和标签,每次获取一个batch_size
logits=model.model(images)#这里得到的是一个[batch_size,num_class]的矩阵,里面的数字越大,是这一类的概率就越大
total_loss,loss_averages_op=loss(logits,labels)#计算总损失和滑步平均损失
train_op=gradient_descent(total_loss,loss_averages_op,global_step)#梯度下降求值
#一个功能类,其本身不具有功能,专门用来拓展,可以点进去看源码
class _LoggerHook(tf.train.SessionRunHook):
#在图创建之前执行,仅执行一次
def begin(self):
self._step = -1
self._start_time = time.time()
def after_create_session(self, session, coord):
#图创建之后执行
pass
#每个step之前都执行
def before_run(self, run_context):
self._step += 1
return tf.train.SessionRunArgs(total_loss) # 这里可以把需要获取的变量添加进去,在下一个函数中可以获取它的值
#每个step之后都执行
def after_run(self, run_context, run_values):
if self._step % FLAGS.log_frequency == 0:
current_time = time.time()
duration = current_time - self._start_time
loss_value = run_values.results#run_value就是获取前面before_run获取的变量的值
format_str = ('step %d, loss = %.2f (use_time =%.1f s )')
print (format_str % (self._step, loss_value,duration))
def end(self, session):
#训练结束后执行
pass
#创建训练会话config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_log_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(total_loss),_LoggerHook()],
) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
def main(argv=None):
train()
if __name__ == '__main__':
tf.app.run()
没有合适的资源?快使用搜索试试~ 我知道了~
tensorflow下编写CNN网络的框架
共4个文件
py:4个
需积分: 16 22 下载量 151 浏览量
2018-06-08
12:48:09
上传
评论 1
收藏 8KB ZIP 举报
温馨提示
Windows下tensorflow-GPU-1.8的python下的CNN模板,内置的lenet-5模型,我特意把它修改成很容易换成别的网络模型,我是在官网的cifar-10代码的基础上进行改动,里面写了详细的中文注释,我还加上了获取混淆矩阵和分类错误图片的路径的功能,更加方便分析模型性能。
资源推荐
资源详情
资源评论
收起资源包目录
CNN_model.zip (4个子文件)
test.py 5KB
train.py 6KB
get_data.py 4KB
model.py 4KB
共 4 条
- 1
资源评论
lsjweiyi
- 粉丝: 2w+
- 资源: 3
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功