没有合适的资源?快使用搜索试试~ 我知道了~
Tensorflow2.0:加载与识别经典数据集MINIST
9 下载量 88 浏览量
2021-01-20
11:35:12
上传
评论
收藏 72KB PDF 举报
温馨提示
试读
3页
一 实现思路 1. 加载 MNIST 数据集,得到训练集与测试集 2. 将训练集与测试集转换为DataSet对象 3. 将数据顺序打散 避免每次读取数据顺序相同,使得模型记住训练集的一些特点,降低模型泛化能力。 4. 设置批训练 从训练集总数中随机抽取batchsize个样本,来进行模型训练,相比于使用所用样本构建模型,批训练花费的时间更少,计算效率更高。每训练一个次,就叫一个step,当经历若干个step使得把训练集所有样本训练过以后,那叫一个epoch 5. 数据预处理 图片像素值进行标准化,使得处于0到1的区间 图片的类别转化成one-hot编码 图片的标签是数字0到数字10,是属于多分
资源详情
资源评论
资源推荐
Tensorflow2.0:加载与识别经典数据集:加载与识别经典数据集MINIST
一一 实现思路实现思路
1. 加载加载 MNIST 数据集,得到训练集与测试集数据集,得到训练集与测试集
2. 将训练集与测试集转换为将训练集与测试集转换为DataSet对象对象
3. 将数据顺序打散将数据顺序打散
避免每次读取数据顺序相同,使得模型记住训练集的一些特点,降低模型泛化能力。
4. 设置批训练设置批训练
从训练集总数中随机抽取batchsize个样本,来进行模型训练,相比于使用所用样本构建模型,批训练花费的时间更少,计算
效率更高。每训练一个次,就叫一个step,当经历若干个step使得把训练集所有样本训练过以后,那叫一个epoch
5. 数据预处理数据预处理
图片像素值进行标准化,使得处于0到1的区间
图片的类别转化成one-hot编码
图片的标签是数字0到数字10,是属于多分类问题,为了能够量化类别,将图片的类别转化成长度为10位数字的one-hot编
码,便于和神经网络输出结果比较,计算其损失。
6. 神经网络构建神经网络构建
其主要的流程为:
设置学习率
网络结构参数初始化
计算前向传播
计算损失函数
计算梯度
根据梯度更新参数(梯度下降法)
每经过固定的step记录和输出训练误差,即均方根误差
每经过固定的step,输出测试误差,即分类正确率
二二 实现方式实现方式
1.数据处理阶段数据处理阶段
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets # 导入经典数据集加载模块
# 1. 加载 MNIST 数据集
(x, y), (x_test, y_test) = datasets.mnist.load_data() # 返回数组的形状
# 2. 将数据集转换为DataSet对象,不然无法继续处理
train_db = tf.data.Dataset.from_tensor_slices((x, y))
# print(train_db) #
# 3. 将数据顺序打散
train_db = train_db.shuffle(10000) # 数字为缓冲池的大小
# print(train_db) #
# 4. 设置批训练
train_db = train_db.batch(512) # batch size 为 128
# print(train_db) #
# 5. 预处理函数
def preprocess(x, y): # 输入x的shape 为[b, 32, 32], y为[b] # 将像素值标准化到 0~1区间
x = tf.cast(x, dtype=tf.float32) / 255.
# 将图片改为28*28大小的
x = tf.reshape(x, [-1, 28 * 28])
# 这个reshape我认为是和数据的存储顺序发生冲突,读取的数据应该不是原图的数据,而是被打乱的数据
# 将数据集的类别标签(数字0-10)转换为one-hot 编码
y = tf.cast(y, dtype=tf.int32) # 转成整型张量
y = tf.one_hot(y, depth=10)
return x, y
# 将数据集传入预处理函数,train_db支持map映射函数
train_db = train_db.map(preprocess)
# print(train_db) #
# 设置训练20个epoch
train_db = train_db.repeat(20) # 将train_db在内部迭代20遍
# 查看train_db的结构
x, y = next(iter(train_db))
print(x, y)
print('train sample:', x.shape, y.shape) # (512, 784) (512, 10)
# 从上面可以看出,现在的train_db已经变成可每份512*784的矩阵,有变成了每份512*10的矩阵,784表示输入的特征
数,10表示输出的类别所对应的向量,即one-hot编码
weixin_38746818
- 粉丝: 7
- 资源: 910
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论0