没有合适的资源?快使用搜索试试~ 我知道了~
使用TensorFlow实现二分类的方法示例
13 下载量 153 浏览量
2020-09-19
15:38:57
上传
评论
收藏 118KB PDF 举报
温馨提示
试读
4页
主要介绍了使用TensorFlow实现二分类的方法示例,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
资源推荐
资源详情
资源评论
使用使用TensorFlow实现二分类的方法示例实现二分类的方法示例
主要介绍了使用TensorFlow实现二分类的方法示例,小编觉得挺不错的,现在分享给大家,也给大家做个参
考。一起跟随小编过来看看吧
使用使用TensorFlow构建一个神经网络来实现二分类,主要包括输入数据格式、隐藏层数的定义、损失函数的选择、优化函数的构建一个神经网络来实现二分类,主要包括输入数据格式、隐藏层数的定义、损失函数的选择、优化函数的
选择、输出层。下面通过选择、输出层。下面通过numpy来随机生成一组数据,通过定义一种正负样本的区别,通过来随机生成一组数据,通过定义一种正负样本的区别,通过TensorFlow来构造一个神经网络来构造一个神经网络
来实现二分类。来实现二分类。
一、神经网络结构一、神经网络结构
输入数据:定义输入一个二维数组(x1,x2),数据通过numpy来随机产生,将输出定义为0或1,如果x1+x2<1,则y为1,否则
y为0。
隐藏层:定义两层隐藏层,隐藏层的参数为(2,3),两行三列的矩阵,输入数据通过隐藏层之后,输出的数据为(1,3),t通过矩
阵之间的乘法运算可以获得输出数据。
损失函数:使用交叉熵作为神经网络的损失函数,常用的损失函数还有平方差。
优化函数:通过优化函数来使得损失函数最小化,这里采用的是Adadelta算法进行优化,常用的还有梯度下降算法。
输出数据:将隐藏层的输出数据通过(3,1)的参数,输出一个一维向量,值的大小为0或1。
二、TensorFlow代码的实现
import tensorflow as tf
from numpy.random import RandomState
if __name__ == "__main__":
#定义每次训练数据batch的大小为8,防止内存溢出
batch_size = 8
#定义神经网络的参数
w1 = tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
w2 = tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))
#定义输入和输出
x = tf.placeholder(tf.float32,shape=(None,2),name="x-input")
y_ = tf.placeholder(tf.float32,shape=(None,1),name="y-input")
#定义神经网络的前向传播过程
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)
#定义损失函数和反向传播算法
#使用交叉熵作为损失函数
#tf.clip_by_value(t, clip_value_min, clip_value_max,name=None)
#基于min和max对张量t进行截断操作,为了应对梯度爆发或者梯度消失的情况
cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y,1e-10,1.0)))
# 使用Adadelta算法作为优化函数,来保证预测值与实际值之间交叉熵最小
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)
#通过随机函数生成一个模拟数据集
rdm = RandomState(1)
# 定义数据集的大小
dataset_size = 128
# 模拟输入是一个二维数组
X = rdm.rand(dataset_size,2)
#定义输出值,将x1+x2 < 1的输入数据定义为正样本
Y = [[int(x1+x2 < 1)] for (x1,x2) in X]
#创建会话运行TensorFlow程序
资源评论
weixin_38675341
- 粉丝: 8
- 资源: 998
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 基于UC3842+LTS26Q1565A设计PC机充电器 硬件(原理图+PCB)工程文件.zip
- Hive SQL经典面试题,大数据SQL经典面试题
- Qt实现喷码器代码,实现二维码、条形码、图形的旋转、移动等
- 基于LM324芯片比较器传感器模块AD09设计硬件(原理图+PCB)工程文件.zip
- HTTP请求 - 记一笔-添加记账.jmx
- 2205040245凡永超硬间隔svm.ipynb
- Qt喷码器demo,演示软件,不是代码
- 目标跟踪-基于目标中心点同时进行目标检测+目标跟踪算法实现-项目源码-优质项目实战.zip
- Python《文本特征分析-全唐诗数据挖掘及分析 》+源代码
- Netron-Setup-4.5.0
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功