没有合适的资源?快使用搜索试试~ 我知道了~
今天小编就为大家分享一篇pytorch之添加BN的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
资源推荐
资源详情
资源评论
pytorch之添加之添加BN的实现的实现
今天小编就为大家分享一篇pytorch之添加BN的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小
编过来看看吧
pytorch之添加BN层
批标准化批标准化
模型训练并不容易,特别是一些非常复杂的模型,并不能非常好的训练得到收敛的结果,所以对数据增加一些预处理,同时使
用批标准化能够得到非常好的收敛结果,这也是卷积网络能够训练到非常深的层的一个重要原因。
数据预处理数据预处理
目前数据预处理最常见的方法就是中心化和标准化,中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征
维度上减去对应的均值,最后得到 0 均值的特征。标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有
着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1之间,这两种方法非
常的常见,如果你还记得,前面我们在神经网络的部分就已经使用了这个方法实现了数据标准化,至于另外一些方法,比如
PCA 或者 白噪声已经用得非常少了。
Batch Normalization
前面在数据预处理的时候,尽量输入特征不相关且满足一个标准的正态分布,
这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的
N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。
所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从
标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。batch
normalization 的实现非常简单,对于给定的一个 batch 的数据 算法的公式如下
第一行和第二行是计算出一个 batch 中数据的均值和方差,接着使用第三个公式对 batch 中的每个数据点做标准化,ϵ是为了
计算稳定引入的一个小的常数,通常取 ,最后利用权重修正得到最后的输出结果,非常的简单,
实现一下简单的一维的情况,也就是神经网络中的情况
import sys
sys.path.append('..')
import torch
def simple_batch_norm_1d(x, gamma, beta):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)
可以看到这里一共是 5 个数据点,三个特征,每一列表示一个特征的不同数据点,使用批标准化之后,每一列都变成了标准
的正态分布这个时候会出现一个问题,就是测试的时候该使用批标准化吗?答案是肯定的,因为训练的时候使用了,而测试的
时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然
资源评论
weixin_38677046
- 粉丝: 6
- 资源: 912
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功