import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import os
# 配置神经网络的参数
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
# 模型保存的路径和文件名
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"
def train(mnist):
# 定义输入输出placeholder
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
# 使用mnist_inference.py文件中定义的前向传播的过程。
y = mnist_inference.inference(x, regularizer)
global_step = tf.Variable(0, trainable=False)
# 定义损失函数、学习率、滑动平均操作以及训练过程
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,
staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
# 使用tf.train.Saver()类保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
# 初始化所有变量
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
# 输出当前的训练情况,损失函数的大小
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
def main(argv=None):
mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)
train(mnist)
if __name__ == '__main__':
tf.app.run()
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于Python实现的手写数字识别系统源码+全部数据.zip 主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。也可作为课程设计、期末大作业。包含全部项目源码、该项目可以直接作为毕设使用。项目都经过严格调试,下载即用确保可以运行! 基于Python实现的手写数字识别系统源码+全部数据.zip 主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。也可作为课程设计、期末大作业。包含全部项目源码、该项目可以直接作为毕设使用。项目都经过严格调试,下载即用确保可以运行!基于Python实现的手写数字识别系统源码+全部数据.zip 主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。也可作为课程设计、期末大作业。包含全部项目源码、该项目可以直接作为毕设使用。项目都经过严格调试,下载即用确保可以运行!基于Python实现的手写数字识别系统源码+全部数据.zip 主要针对计算机相关专业的正在做毕设的学生和需要项目实战练习的学习者。也可作为课程设计、期末大作业。包含全部项目源码、该项目可以直接作为毕设使用。项目都经过严格调试,下载即用确保可以
资源推荐
资源详情
资源评论
收起资源包目录
基于Python实现的手写数字识别系统源码.zip (42个子文件)
主-master
mnist_inference.py 1KB
minist_eval.py 2KB
picture
number2.png 130KB
9.png 201KB
3.png 205KB
0.png 200KB
number3.png 69KB
1.png 172KB
number4.png 152KB
number5.png 149KB
number8.png 122KB
6.png 205KB
number6.png 209KB
number1.png 65KB
number9.png 139KB
5.png 195KB
4.png 209KB
8.png 220KB
number0.png 101KB
7.png 289KB
2.png 182KB
number7.png 196KB
app.py 2KB
MNIST_model
checkpoint 283B
mnist_model-26001.data-00000-of-00001 3.03MB
mnist_model-27001.data-00000-of-00001 3.03MB
mnist_model-27001.meta 62KB
mnist_model-29001.data-00000-of-00001 3.03MB
mnist_model-28001.meta 62KB
mnist_model-30000.index 470B
mnist_model-30000.meta 62KB
mnist_model-29001.meta 62KB
mnist_model-29001.index 470B
mnist_model-27001.index 470B
mnist_model-28001.data-00000-of-00001 3.03MB
mnist_model-28001.index 470B
mnist_model-30000.data-00000-of-00001 3.03MB
mnist_model-26001.index 470B
mnist_model-26001.meta 62KB
mnist_train.py 3KB
__pycache__
mnist_inference.cpython-37.pyc 1003B
mnist_train.cpython-37.pyc 2KB
共 42 条
- 1
资源评论
程序员张小妍
- 粉丝: 1w+
- 资源: 3252
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功