# -*- coding: utf-8 -*-
"""
定义:神经网络的训练过程
Created on Sun Oct 29 15:00:57 2017
@author: 余星星
e-mail:2549721818@qq.com
"""
#导入函数所需要的各种包
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference #导入前向传播函数
import os #导入系统函数
#### 1. 配置神经网络的参数。
BATCH_SIZE = 100 #一次训练的大小
LEARNING_RATE_BASE = 0.8 #基础的学习率
LEARNING_RATE_DECAY = 0.99#学习率的衰减率
REGULARIZATION_RATE = 0.0001#描述模型复杂度的正则化项,在损失函数中的系数
TRAINING_STEPS = 30000 #训练轮数
MOVING_AVERAGE_DECAY = 0.99#滑动平均的衰减值
#文件的保存路径和文件名
MODEL_SAVE_PATH="path/yxx/MNIST_model/"
MODEL_NAME="mnist_model"
###2,训练模型的过程 ,定义神经网络训练函数。
def train(mnist,start):
#2.1,定义输入输出的placeholder,即节点
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
#计算L2正则化损失函数
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
#直接调用mnist_inference.py函数中定义的前向传播过程
y = mnist_inference.inference(x, regularizer)
#定义训练存储轮数的变量,代表轮数的变量不可训练
global_step = tf.Variable(0, trainable=False)
#2.2,定义损失函数、学习率、滑动平均操作以及训练过程
#给定滑动平均衰减率和训练轮数变量,初始化滑动平均类(给定训练轮数的变量可以加快训练早期变量的更新速度)
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
#在所有代表神经网络参数的变量上使用滑动平均
variables_averages_op = variable_averages.apply(tf.trainable_variables())
#计算交叉熵作为刻画预测值和真实值之间的差距的损失函数,使用tensorflow自带函数计算交叉熵
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
#计算当前所有batch中的样例的交叉熵平均值
cross_entropy_mean = tf.reduce_mean(cross_entropy)
#总损失等于交叉熵损失和正则化损失之和
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
#设置指数衰减学习率
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,#基础学习率,随着迭代的进行,更新变量时使用的学习率在此基础递增
global_step, #当前迭代轮数
mnist.train.num_examples / BATCH_SIZE,#过完所有的训练数据需要的迭代次数
LEARNING_RATE_DECAY,#学习率衰减速度
staircase=True)
#使用tf.train.GradientDescentOptimizer优化算法来优化损失函数,该损失函数包含交叉熵损失很L2正则化损失
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
#训练神经网络模型,没过一边数据,既需要反向传播更新神经网络的参数,又要更新每一个参数的滑动平均值
#为了一次完成多个操作,采用了tf.control_dependencies
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
#3,初始化tensorflow持久类
saver = tf.train.Saver()
#初始化会话,开始训练过程
with tf.Session() as sess:
#初始化所有变量。
tf.initialize_all_variables().run()
#实现断点续训代码
if start==1:
#tf.train.latest_checkpoint()自动获取最后一次保存的模型
model_file=tf.train.latest_checkpoint(MODEL_SAVE_PATH)
saver.restore(sess,model_file)
#在测试的过程中,不在测试模型在验证数据上面的表现,验证和测试有独立程序mnist_eval.py来完成。
for i in range(TRAINING_STEPS):#迭代训练神经网络。
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
#每1000轮保存一个模型
if i % 1000 == 0:
#输出当前训练情况,模型在当前训练batch上面损失函数大小
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
#保存当前模型,可以让每个被保存的文件名字的末尾加上训练的轮数
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
#4定义训练集的主函数
def main(argv=None):
start=input("请选择:您希望从头开始训练还是进行断点续训?\n从头开始训练请按0 断点续训请按1\n")
mnist = input_data.read_data_sets("path/yxx/MNIST_data", one_hot=True)
#训练模型
train(mnist,start)
#TensorFlow提供一个主程序入口,tf.app.run会调用上面定义的main函数。
if __name__ == '__main__':
tf.app.run()
没有合适的资源?快使用搜索试试~ 我知道了~
tensorflow实现训练代码(mnist数据集&断点续训)
共43个文件
gz:12个
data-00000-of-00001:8个
index:8个
4星 · 超过85%的资源 需积分: 46 47 下载量 50 浏览量
2017-11-04
12:14:20
上传
评论 2
收藏 66.86MB ZIP 举报
温馨提示
利用tensorflow实现读取数据集,训练数据集,最后得出准确率。基于pytho语言,有详细的注释。
资源推荐
资源详情
资源评论
收起资源包目录
tensorflow训练代码.zip (43个子文件)
tensorflow训练代码
mnist_inference.py 3KB
mnist_train.py 5KB
path
yxx
MNIST_data.zip 11.06MB
MNIST_data
t10k-labels-idx1-ubyte.gz 4KB
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte.gz 9.45MB
train-labels-idx1-ubyte.gz 28KB
MNIST_model
mnist_model-1.data-00000-of-00001 3.03MB
mnist_model-1.index 470B
mnist_model-26001.data-00000-of-00001 3.03MB
mnist_model-28001.meta 66KB
mnist_model-27001.index 470B
mnist_model-23001.index 470B
mnist_model-1.meta 66KB
mnist_model-25001.index 470B
mnist_model-23001.meta 66KB
mnist_model-29001.data-00000-of-00001 3.03MB
mnist_model-24001.index 470B
mnist_model-26001.meta 66KB
mnist_model-23001.data-00000-of-00001 3.03MB
mnist_model-28001.index 470B
mnist_model-25001.meta 66KB
mnist_model-29001.index 470B
mnist_model-29001.meta 66KB
mnist_model-28001.data-00000-of-00001 3.03MB
mnist_model-26001.index 470B
mnist_model-25001.data-00000-of-00001 3.03MB
mnist_model-27001.data-00000-of-00001 3.03MB
mnist_model-24001.data-00000-of-00001 3.03MB
mnist_model-27001.meta 66KB
checkpoint 83B
mnist_model-24001.meta 66KB
__pycache__
mnist_train.cpython-36.pyc 3KB
mnist_inference.cpython-36.pyc 1KB
MNIST_data
t10k-labels-idx1-ubyte.gz 4KB
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte.gz 9.45MB
train-labels-idx1-ubyte.gz 28KB
mnist_eval.py 4KB
MNIST_data
t10k-labels-idx1-ubyte.gz 4KB
t10k-images-idx3-ubyte.gz 1.57MB
train-images-idx3-ubyte.gz 9.45MB
train-labels-idx1-ubyte.gz 28KB
共 43 条
- 1
资源评论
- qq_173197832017-12-28试一试好不好用
- liaooail2018-03-21多些分享,有用
Bluesyxx
- 粉丝: 7
- 资源: 10
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 基于STM8S103F3P6+STM8S207C8T6+STM32F103 单片机三合一最小系统开发板硬件(原理图+PCB)工程
- 基于C语言实现的打印杨辉三角
- 基于ASIO的插件式服务器,支持TCP,UDP,串口,Http,Websocket统一化的数据接口,隔离开发人员和IO之间的操作
- stm32 usb接口通信
- Chessmate是一款完全免费的国际象棋学习软件,支持引擎分析,学开局、残局、棋书解读、大数据分析等功能
- 总结整理的Android面试Java基础知识点面试资料精编汇总文档资料合集.zip
- .android_lq
- FDN5632N-VB一款SOT23封装N-Channel场效应MOS管
- 毛老板-2404250902.amr
- Java类加载流程(双亲委派)流程图.zip
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功