'''''''''
@file: net_model.py
@author: MRL Liu
@time: 2021/4/14 22:52
@env: Python,Numpy
@desc: 本模块提供定义模型、训练模型、评估模型的方法
@ref:
@blog: https://blog.csdn.net/qq_41959920
'''''''''
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import datahelpter
# 神经网络参数
INPUT_NODE=784 # 输入维度
OUTPUT_NODE=10 # 输出维度
LAYER1_NODE=500 # 第一层维度
# 学习率
LEARNING_RATE_BASE=0.8
LEARNING_RATE_DECAY=0.99
REGULARIZTION_RATE=0.0001
# 训练参数
TRAINING_STEPS=3000
BATCH_SIZE=100
MOVING_AVERAGE_DECAY=0.99
#模型保存的路径和文件名
DATASET_SAVE_PATH='./mnist/' # 数据集保存路径
MODEL_SAVE_PATH='./models/'
LOGS_SAVE_PATH='./logs/'
MODEL_NAME='model.ckpt'
class net_model(object):
def __init__(self,num_examples):
self.n_input =INPUT_NODE
self.n_layer_1 =LAYER1_NODE
self.n_output = OUTPUT_NODE
# 滑动平均模型相关参数
self.moving_average_decay = MOVING_AVERAGE_DECAY
# 训练相关参数
self.training_step = TRAINING_STEPS
self.batch_size = BATCH_SIZE
# 学习率的相关参数(指数衰减法)
self.learn_rate_base = LEARNING_RATE_BASE # 学习率初始值
self.learn_rate_decay = LEARNING_RATE_DECAY # 学习率衰减率
self.learn_rate_num = num_examples / self.batch_size, # 学习率衰减次数
# 相关路径
self.model_save_path = MODEL_SAVE_PATH
self.logs_save_path = LOGS_SAVE_PATH
self.model_name = MODEL_NAME
def test_accuracy(self,_images,_labels):
with tf.Graph().as_default() as g:
x_input = tf.placeholder(tf.float32, [None, self.n_input], name='x-input')
y_input = tf.placeholder(tf.float32, [None, self.n_output], name='y-input')
output = self._define_net(x_input, regularizer__function=None, is_historgram=False)
# 计算准确率
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(y_input, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 滑动平均模型变量
variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY) # 定义一个滑动平均类
variables_to_restore = variables_averages.variables_to_restore() # 生成变量重命名的列表
# 创建加载变量重命名后的保存器
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(self.model_save_path) # 获取ckpt的模型文件的路径
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path) # ckpt.model_checkpoint_path保存了最新次数的模型文件路径
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] # 获取训练次数
accuracy_score = sess.run(accuracy, feed_dict={x_input: _images,y_input: _labels}) # 运行计算图,获取准确率
print('After %s training step(s), accuracy on validation is %g.' % (global_step, accuracy_score))
else:
print('No checkpoint file found')
return
def test_random(self,_images,_labels):
# 随机挑选9个照片
random_indices = random.sample(range(len(_images)), min(len(_images), 9))
images, labels = zip(*[(_images[i], _labels[i]) for i in random_indices])
# 加载模型
pred = self._run_saved_model(images,labels)
if pred is not None:
datahelpter.plot_images(images=images, cls_true=np.argmax(labels, 1), cls_pred=np.argmax(pred, 1),img_size=28, num_channels=1)
def _run_saved_model(self,images,labels):
# 加载模型
with tf.Graph().as_default() as g:
x_input = tf.placeholder(tf.float32, [None, self.n_input], name='x-input')
y_input = tf.placeholder(tf.float32, [None, self.n_output], name='y-input')
output = self._define_net(x_input, regularizer__function=None, is_historgram=False)
# 滑动平均变量
variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY) # 定义一个滑动平均类
variables_to_restore = variables_averages.variables_to_restore() # 生成变量重命名的列表
# 创建加载变量重命名后的保存器
saver = tf.train.Saver(variables_to_restore)
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(self.model_save_path) # 获取ckpt的模型文件的路径
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path) # 恢复模型参数
pred = sess.run(output, feed_dict= {x_input: images, y_input: labels}) # 运行计算图,获取准确率
return pred
else:
print('No checkpoint file found')
return None
def train(self,train):
""" 训练一个计算图模型"""
self._define_graph()
merged_summary_op = tf.summary.merge_all() # 合并所有的summary为一个操作节点,方便运行
saver = tf.train.Saver()# 网络模型保存器
# 开始训练
with tf.Session() as sess:
tf.global_variables_initializer().run() # 初始化所有变量
train_writer = tf.summary.FileWriter(self.logs_save_path, sess.graph) # 文件输出对象,用于生成graph event文件
for i in range(1, self.training_step + 1):
xs, ys = train.next_batch(self.batch_size) # 获取一个批次
# 定期保存网络
if i % 1000 == 0:
saver.save(sess, os.path.join(self.model_save_path, self.model_name), global_step=self.global_step) # 保存cnpk模型
# 执行优化器、损失值和step
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)#配置运行时需要记录的信息
run_metadata = tf.RunMetadata()#运行时记录运行信息的proto
_, loss_value, step = sess.run([self.optimizer, self.loss, self.global_step], feed_dict={self.x_input: xs, self.y_input: ys},
options=run_options, run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'step%03d' % i) # #将节点在运行时的信息写入日志文件
print('Epoich: %d , loss: %g. and save model successfully' % (step, loss_value))
# 定期打印信息和记录变量
elif i % 10 == 0:
# 直接执行优化器、损失值和step和合并操作
_, loss_value, step, summary = sess.run([self.optimizer, self.loss, self.global_step, merged_summary_op],
feed_dict={self.x_input: xs, self.y_input: ys})
print('Epoich: %d , loss: %g.' % (step, loss_value))
train_writer.add_summary(summary, i) # 添加到graph event文件中用于TensorBoard的显示
else:
_ = sess.run( [self.optimizer],feed_dict={self.x_input: xs, self.y_input: ys})# 优化参数
train_writer.close()
def _define_graph(self):
""" 定义一个计算图"""
# 定义计算图的输入结构
with tf.name_scope('input'):
self.x_input = tf.placeholder(dtype=tf.float32, shape=[None, self.n_input], name='x-input') # 网络输入格式
self.y_input = tf.placeholder(dtype=tf.float32, shape=[None, self.n_output], name='y-input') # 网络标签格式
# 定义计算图
没有合适的资源?快使用搜索试试~ 我知道了~
基于TensorFlow框架训练的一个全连接网络的手写数字识别器.zip
共18个文件
gz:4个
meta:3个
index:3个
1.该资源内容由用户上传,如若侵权请联系客服进行举报
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
2.虚拟产品一经售出概不退款(资源遇到问题,请及时私信上传者)
版权申诉
0 下载量 102 浏览量
2024-02-20
13:56:16
上传
评论
收藏 20.06MB ZIP 举报
温馨提示
基于MNIST的手写数字识别项目已是深度学习入门的必备项目,但区别于其他,本项目的特色是添加了模型的保存与加载功能、TensorFlow可视化功能、指数衰减法的学习率、滑动平均模型技术、L2正则化等先进技术和可视化手写数字图片功能等。整个项目基于良好的面向对象思想,方法定义层层推进,可以说是非常好的总结性学习材料。
资源推荐
资源详情
资源评论
收起资源包目录
基于TensorFlow框架训练的一个全连接网络的手写数字识别器.zip (18个子文件)
MRL-Mnist-Number-Master-main
net_model.py 16KB
logs
events.out.tfevents.1618902502.DESKTOP-E4438MJ 4.22MB
models
model.ckpt-2999.data-00000-of-00001 3.03MB
checkpoint 178B
model.ckpt-999.data-00000-of-00001 3.03MB
model.ckpt-2999.index 497B
model.ckpt-1999.meta 76KB
model.ckpt-1999.data-00000-of-00001 3.03MB
model.ckpt-999.index 497B
model.ckpt-2999.meta 76KB
model.ckpt-1999.index 497B
model.ckpt-999.meta 76KB
__pycache__
datahelpter.cpython-36.pyc 2KB
datahelpter.py 2KB
mnist
t10k-images-idx3-ubyte.gz 1.57MB
train-labels-idx1-ubyte.gz 28KB
train-images-idx3-ubyte.gz 9.45MB
t10k-labels-idx1-ubyte.gz 4KB
共 18 条
- 1
资源评论
博士僧小星
- 粉丝: 1883
- 资源: 5877
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 【代码】RT1021-100P MicroPython固件使用例程.7z
- springboot176基于Spring Boot的装饰工程管理系统.rar
- springboot171社区医院管理系统.rar
- 1251+1255.pdf
- springboot169基于vue的工厂车间管理系统的设计.rar
- springboot168基于springboot + vue的疫情隔离管理系统.rar
- 高校人事管理系统.zip
- ACM比赛算法:ACM 树同构-比赛常用的算法
- Python课程大作业二手车价格预测案例数据挖掘源码+数据集+实验报告+详细注释.zip
- springboot167基于springboot的医院后台管理系统的设计与实现.rar
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功