import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 50000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH="C:\\Users\\84031\\Desktop\\model"
MODEL_NAME="mnist_model"
def backward(mnist):
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
y = mnist_forward.forward(x, REGULARIZER)
global_step = tf.Variable(0, trainable=False)
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
cem = tf.reduce_mean(ce)
loss = cem + tf.add_n(tf.get_collection('losses'))
learning_rate = tf.train.exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY,
staircase=True)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
ema_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name='train')
saver = tf.train.Saver()
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
for i in range(STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)
def main():
mnist =input_data.read_data_sets('C:\\Users\\84031\\Desktop\\mnist',one_hot=True)
backward(mnist)
if __name__ == '__main__':
main()
基于tensorflow的全连接神经网络的手写数字识别
需积分: 50 41 浏览量
2019-05-06
16:16:31
上传
评论 1
收藏 11.1MB ZIP 举报
gokingd
- 粉丝: 55
- 资源: 4
最新资源
- 549springboot + vue 民宿管理平台.zip (可运行源码+数据库文件+文档)
- ZArchiver.Pro_0.9.5.apk
- vmware环境配置.mp4
- 548springboot + vue 大学生社团活动平台.zip(可运行源码+数据库文件+文档)
- 微信小程序 辩论倒计时小程序源码 作业设计demo 计算机专业参考
- 深入探究文件IO,嵌入式Linux
- 微信备忘录小程序源码 作业设计demo 计算机专业作业
- 微信小程序 仿百度小说小程序 看小说小程序 实现源码
- 锂电资料包-锂离子电池技术干货资料合集.zip
- (王道计算机组成原理)第三章存储系统-第二节1:主存储器基本构成、基本的半导体原件和存储器芯片的原理_主存储器与存储芯片-CSDN博客 (2024….html
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈