function [net_iterative, net_weights_inc, net_grad_ssqr] = learner(net_iterative, momentum, net_weights_inc, net_grad_ssqr, net_grad, opts, epoch, bid, num_batch, GD)
num_net_layer = length(net_iterative);
net_ada_eta = net_weights_inc;
if strcmp(GD,'G')
batchNorm = opts.batchNorm_G;
else
batchNorm = opts.batchNorm_D;
end
for ll = 1:num_net_layer
switch opts.learner
case 'sgd'
net_weights_inc(ll).W = momentum*net_weights_inc(ll).W + opts.sgd_learn_rate(epoch)*net_grad(ll).W;
net_weights_inc(ll).b = momentum*net_weights_inc(ll).b + opts.sgd_learn_rate(epoch)*net_grad(ll).b;
if opts.batchNormlization && batchNorm(ll) == 1
net_weights_inc(ll).gamma = momentum*net_weights_inc(ll).gamma + opts.sgd_learn_rate(epoch)*net_grad(ll).gamma;
net_weights_inc(ll).beta = momentum*net_weights_inc(ll).beta + opts.sgd_learn_rate(epoch)*net_grad(ll).beta;
end
case 'ada_sgd'
net_grad_ssqr(ll).W = net_grad_ssqr(ll).W + (net_grad(ll).W).^2;
net_grad_ssqr(ll).b = net_grad_ssqr(ll).b + (net_grad(ll).b).^2;
net_ada_eta(ll).W = opts.learner_scale./sqrt(net_grad_ssqr(ll).W + 10^-8);
net_ada_eta(ll).b = opts.learner_scale./sqrt(net_grad_ssqr(ll).b + 10^-8);
net_weights_inc(ll).W = momentum*net_weights_inc(ll).W + net_ada_eta(ll).W.*net_grad(ll).W;
net_weights_inc(ll).b = momentum*net_weights_inc(ll).b + net_ada_eta(ll).b.*net_grad(ll).b;
if opts.batchNormlization && batchNorm(ll) == 1
net_grad_ssqr(ll).gamma = net_grad_ssqr(ll).gamma + (net_grad(ll).gamma).^2;
net_grad_ssqr(ll).beta = net_grad_ssqr(ll).beta + (net_grad(ll).beta).^2;
net_ada_eta(ll).gamma = opts.learner_scale./sqrt(net_grad_ssqr(ll).gamma + 10^-8);
net_ada_eta(ll).beta = opts.learner_scale./sqrt(net_grad_ssqr(ll).beta + 10^-8);
net_weights_inc(ll).gamma = momentum*net_weights_inc(ll).gamma + net_ada_eta(ll).gamma.*net_grad(ll).gamma;
net_weights_inc(ll).beta = momentum*net_weights_inc(ll).beta + net_ada_eta(ll).beta.*net_grad(ll).beta;
end
case 'ada_delta'
gamma = 0.9;
net_grad_ssqr(ll).W = gamma*net_grad_ssqr(ll).W + (1-gamma)*(net_grad(ll).W).^2;
net_grad_ssqr(ll).b = gamma*net_grad_ssqr(ll).b + (1-gamma)*(net_grad(ll).b).^2;
net_ada_eta(ll).W = sqrt((net_grad_ssqr(ll).W2 + 10^-8)./(net_grad_ssqr(ll).W + 10^-8));
net_ada_eta(ll).b = sqrt((net_grad_ssqr(ll).b2 + 10^-8)./(net_grad_ssqr(ll).b + 10^-8));
net_weights_inc(ll).W = momentum*net_weights_inc(ll).W + net_ada_eta(ll).W.*net_grad(ll).W;
net_weights_inc(ll).b = momentum*net_weights_inc(ll).b + net_ada_eta(ll).b.*net_grad(ll).b;
net_grad_ssqr(ll).W2 = gamma*net_grad_ssqr(ll).W2 + (1-gamma)*(net_grad(ll).W).^2;
net_grad_ssqr(ll).b2 = gamma*net_grad_ssqr(ll).b2 + (1-gamma)*(net_grad(ll).b).^2;
if opts.batchNormlization && batchNorm(ll) == 1
net_grad_ssqr(ll).gamma = gamma*net_grad_ssqr(ll).gamma + (1-gamma)*(net_grad(ll).gamma).^2;
net_grad_ssqr(ll).beta = gamma*net_grad_ssqr(ll).beta + (1-gamma)*(net_grad(ll).beta).^2;
net_ada_eta(ll).gamma = sqrt((net_grad_ssqr(ll).gamma2 + 10^-8)./(net_grad_ssqr(ll).gamma + 10^-8));
net_ada_eta(ll).beta = sqrt((net_grad_ssqr(ll).beta2 + 10^-8)./(net_grad_ssqr(ll).beta + 10^-8));
net_weights_inc(ll).gamma = momentum*net_weights_inc(ll).gamma + net_ada_eta(ll).gamma.*net_grad(ll).gamma;
net_weights_inc(ll).beta = momentum*net_weights_inc(ll).beta + net_ada_eta(ll).b.*net_grad(ll).beta;
net_grad_ssqr(ll).gamma2 = gamma*net_grad_ssqr(ll).gamma2 + (1-gamma)*(net_grad(ll).gamma).^2;
net_grad_ssqr(ll).beta2 = gamma*net_grad_ssqr(ll).beta2 + (1-gamma)*(net_grad(ll).beta).^2;
end
case 'rms'
gamma = 0.99;
net_grad_ssqr(ll).W = gamma*net_grad_ssqr(ll).W + (1-gamma)*(net_grad(ll).W).^2;
net_grad_ssqr(ll).b = gamma*net_grad_ssqr(ll).b + (1-gamma)*(net_grad(ll).b).^2;
net_ada_eta(ll).W = opts.learner_scale./sqrt(net_grad_ssqr(ll).W + 10^-8);
net_ada_eta(ll).b = opts.learner_scale./sqrt(net_grad_ssqr(ll).b + 10^-8);
net_weights_inc(ll).W = momentum*net_weights_inc(ll).W + net_ada_eta(ll).W.*net_grad(ll).W;
net_weights_inc(ll).b = momentum*net_weights_inc(ll).b + net_ada_eta(ll).b.*net_grad(ll).b;
if opts.batchNormlization && batchNorm(ll) == 1
net_grad_ssqr(ll).gamma = gamma*net_grad_ssqr(ll).gamma + (1-gamma)*(net_grad(ll).gamma).^2;
net_grad_ssqr(ll).beta = gamma*net_grad_ssqr(ll).beta + (1-gamma)*(net_grad(ll).beta).^2;
net_ada_eta(ll).gamma = opts.learner_scale./sqrt(net_grad_ssqr(ll).gamma + 10^-8);
net_ada_eta(ll).beta = opts.learner_scale./sqrt(net_grad_ssqr(ll).beta + 10^-8);
net_weights_inc(ll).gamma = momentum*net_weights_inc(ll).gamma + net_ada_eta(ll).gamma.*net_grad(ll).gamma;
net_weights_inc(ll).beta = momentum*net_weights_inc(ll).beta + net_ada_eta(ll).beta .*net_grad(ll).beta;
end
case 'adam'
beta1 = 0.9; beta2 = 0.999;
timestamp = (epoch-1)*num_batch + bid;
net_grad_ssqr(ll).W = beta1*net_grad_ssqr(ll).W + (1-beta1)*(net_grad(ll).W); % m
net_grad_ssqr(ll).b = beta1*net_grad_ssqr(ll).b + (1-beta1)*(net_grad(ll).b);
net_grad_ssqr(ll).W2 = beta2*net_grad_ssqr(ll).W2 + (1-beta2)*(net_grad(ll).W).^2; % v
net_grad_ssqr(ll).b2 = beta2*net_grad_ssqr(ll).b2 + (1-beta2)*(net_grad(ll).b).^2;
net_weights_inc(ll).W = momentum*net_weights_inc(ll).W + (opts.learner_scale*net_grad_ssqr(ll).W/(1-beta1.^(timestamp)))./(sqrt(net_grad_ssqr(ll).W2/(1-beta2.^(timestamp)) + 10^-8));
net_weights_inc(ll).b = momentum*net_weights_inc(ll).b + (opts.learner_scale*net_grad_ssqr(ll).b/(1-beta1.^(timestamp)))./(sqrt(net_grad_ssqr(ll).b2/(1-beta2.^(timestamp)) + 10^-8));
if opts.batchNormlization && batchNorm(ll) == 1
net_grad_ssqr(ll).gamma = beta1*net_grad_ssqr(ll).gamma + (1-beta1)*(net_grad(ll).gamma); % m
net_grad_ssqr(ll).beta = beta1*net_grad_ssqr(ll).beta + (1-beta1)*(net_grad(ll).beta);
net_grad_ssqr(ll).gamma2 = beta2*net_grad_ssqr(ll).gamma2 + (1-beta2)*(net_grad(ll).gamma).^2; % v
net_grad_ssqr(ll).beta2 = beta2*net_grad_ssqr(ll).beta2 + (1-beta2)*(net_grad(ll).beta).^2;
net_weights_inc(ll).gamma = momentum*net_weights_inc(ll).gamma + (opts.learner_scale*net_grad_ssqr(ll).gamma/(1-beta1.^(timestamp)))./(sqrt(net_grad_ssqr(ll).gamma2/(1-beta2.^(timestamp)) + 10^-8));
net_weights_inc(ll).beta = momentum*net_weights_inc(ll).beta + (opts.learner_scale*net_grad_ssqr(ll).beta /(1-beta1.^(timestamp)))./(sqrt(net_grad_ssqr(ll).beta2 /(1-beta2.^(timestamp)) + 10^-8));
end
case 'ams'
beta1 = 0.9; beta2 = 0.999;
net_grad_ssqr(ll).W = beta1*net_grad_ssqr(ll).W + (1-beta1)*(net_grad(ll).W); %
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
1、资源内容:基于Matlab生成对抗性网络仿真(源码+数据+说明文档).rar 2、适用人群:计算机,电子信息工程、数学等专业的大学生课程设计、期末大作业或毕业设计,作为“参考资料”使用。 3、解压说明:本资源需要电脑端使用WinRAR、7zip等解压工具进行解压,没有解压工具的自行百度下载即可。 4、免责声明:本资源作为“参考资料”而不是“定制需求”不一定能够满足所有人的需求,需要有一定的基础能够看懂代码,能够自行调试代码并解决报错,能够自行添加功能修改代码。由于作者大厂工作较忙,不提供答疑服务,如不存在资源缺失问题概不负责,谢谢理解。
资源推荐
资源详情
资源评论
收起资源包目录
基于Matlab生成对抗性网络仿真(源码+数据+说明文档).rar (45个子文件)
基于Matlab生成对抗性网络仿真(源码+数据+说明文档)
mnist.mat 13.99MB
opt_config.m 1KB
gpu_try.m 358B
说明文档.md 821B
gan.m 970B
gan
utility
genBatchID.m 335B
randInitNet.m 1KB
getNetParamStr.m 304B
randinitWbSparse.m 404B
leakyrelu.m 189B
zeroInitNet.m 1KB
compute_unit_activation.m 553B
softmax.m 143B
relu.m 156B
initRandW.m 376B
mean_var_norm.m 620B
gather_net.m 178B
avgRecur.m 420B
meanVarNormalize_Test.m 273B
relu_grad.m 43B
sigmoid.m 51B
format_print.m 321B
mean_var_norm_testing.m 279B
deltas.m 656B
batchComputeMeanStd.m 405B
meanVarArmaNormalize_Test.m 177B
count_struct.m 188B
compute_unit_gradient.m 561B
meanVarNormalize.m 244B
netRolling.m 396B
initRandWSparse.m 251B
initializeRandWSparse.m 374B
netUnRolling.m 124B
sigmoid_grad.m 43B
real_fake.m 286B
getMSE.m 1KB
main
computeNetGradient.m 6KB
learner.m 10KB
getOutputFromNetSplit.m 935B
forwardPass.m 2KB
batchNorm_forward.m 308B
batchNorm_backpass.m 309B
funcTrain.m 8KB
forwardPass_diff_drop_ratio.m 1KB
getOutputFromNet.m 968B
共 45 条
- 1
资源评论
- weixin_432068972024-05-07资源内容详细全面,与描述一致,对我很有用,有一定的使用价值。
Matlab仿真实验室
- 粉丝: 2w+
- 资源: 2180
下载权益
C知道特权
VIP文章
课程特权
开通VIP
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功