import os
import tensorflow as tf
from 机器学习.mnist import input_data, model
data= input_data.read_data_sets("MNIST_data", one_hot=True)
with tf.variable_scope("regression"):
x=tf.placeholder(tf.float32,[None,784])
y,variables= model.regression(x)
y_=tf.placeholder("float",[None,10])
cross_entropy=tf.reduce_sum(y_*tf.log(y))
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
saver=tf.train.Saver(variables)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for _ in range(1000):
batch_xs,batch_ys=data.train.next_batch(100)
sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})
print((sess.run(accuracy,feed_dict={x:data.test.images,y_:data.test.labels})))
path=saver.save(
sess,os.path.join("C:\\Users\\R\\PycharmProjects\\PyC\\mnisttest\\mnist",'dat','regression.ckpt'),
write_meta_graph=False,write_state=False)
print("Saved:",path)
print(batch_ys.shape)
mnist线性回归预测(含数据集)Python TensorFlow
需积分: 23 59 浏览量
2019-10-09
08:58:36
上传
评论
收藏 11.06MB ZIP 举报
广阔天地,大有可为
- 粉丝: 35
- 资源: 31
最新资源
- python:利用matplotlib绘制直方图
- 基于matlab块匹配全景图像拼接系统代码12
- 基于matlab小波变换图像融合系统代码11
- 精雕3.5NC后置文件
- yolov8n-pose.pt 用 yolov8n-pose.onnx下载
- C++之STL的vector详解,包括初始化和各种函数:vector的初始化、数据的增删查改等
- stable-diffusion-webui-master
- openPLC-Editor C语言编程 在mp157 arm板上调用io等使用记录
- 无人机悬停时间计算软件.rar
- 主要讲解 mybatis中 实体层的属性与表的列不一致时如何处理? 可以采用将列重命名方法还可以采用resultMap 方式
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈