clc, clearvars, close all
%------------------------------------------------------------------------------------------------
%%%%% Deep learning convolutional neural network regression v1.1 %%%%%
% With network parameter gridsearch, input normalization and geometric image augmentation:
% Network is designed to learn to predict a numerical value from images
% Provided example input (X) data is 3-channel (R2*, QSM, GRE mag) MR image slices of ex vivo stroke blood clots
% Provided example output (Y) data is the RBC content of the clots, determined through histological analysis
% Built to iterate cross-validation experiments over a set of specified network parameters (gridsearch)
% Can customize which parameters to iterate over by changing the for loops
% Network settings which produce the highest accuracy will be used to form predictions on a separate test set
% You must add the iterated variables to both the for loops (line 87) and the results table pull (line 275) for this to work
% Network can normalize training data input based on distribution of the output variable using random oversampling (ROS)
% Network can apply random geometric augmentation operations to training data
% And duplicate the training dataset to increase its size prior to augmentation (duplication factor)
% Handles input predictor (x) data as [X, Y, nchannel, nslice] where nchannel = 1 or 3
% Handles input response (Y) data as a column vector [y(1):y(nslice)]
% v1.0.1 code written and posted by Spencer D. Christiansen, October 2021. Tested on Matlab 2020b
% v1.1 updated March 2022:
% -fixed overwrite_crossval_results_table glitch that created new table for every iteration
% -fixed nonexistent aug_params matrix when augment_training_data set to 0
% -minor changes to comments and variable names for clarity
%------------------------------------------------------------------------------------------------
load([ 'test_input_xdata.mat'])
load([ 'test_input_ydata.mat'])
save_directory = '';
summary_file_name = 'Crossval_results_summary_table';
%Network settings:
augment_training_data = 1; %Randomly augment training dataset
normalize_training_input = 1; %Normalize training data based on predictor distribution
overwrite_crossval_results_table = 1; %1: create new cross validation results table, 0: add lines to existing table
%Input data pre-processing:
xdata = permute(xdata,[1,2,4,3]); %input needs to be [X, Y, nchannel, nslice]
n_images = size(ydata,1); %For random splitting of input into training/testing groups
img_number = 1:n_images;
rng(20211013);
idx = randperm(n_images);
%Training/testing data split:
n_test_images = 28; %Number of images to keep for the independent test set
n_crossval_folds = 8; %Number of training cross-validation folds
n_train_images = n_images-n_test_images;
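%The last n_test_images entries of the shuffled index list are held out as the independent test set and removed from the cross-validation pool below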
idx_test = find(ismember(img_number, idx(end-(n_test_images-1):end)));
XTest = xdata(:,:,:,idx_test);
YTest = ydata(idx_test);
xdata(:,:,:,idx_test) = [];
ydata(idx_test) = [];
img_number(idx_test) = [];
idx(end-(n_test_images-1):end) = [];
%Network parameters (can iterate over by moving into for loop):
%params.optimizer = {'sgdm','adam'}; %'sgdm' | 'adam'
params.batch_size = 8;
%params.max_epochs = [4,6,8];
params.learn_rate = 0.001;
params.learn_rate_drop_factor = 0.1;
params.learn_rate_drop_period = 20;
params.learn_rate_schedule = 'none'; %'none', 'piecewise'
params.shuffle = 'every-epoch';
params.momentum = 0.9; %for sgdm optimizer
params.L2_reg = 0.01;
params.conv_features = [16, 16, 32]; %Number of feature channels in convolutional layers
params.conv_filter_size = 3;
params.conv_padding = 'same';
params.pooling_size = 2;
params.pooling_stride = 1;
params.dropout_factor = 0.2;
params.duplication_factor = 3; %Duplicate training set by N times
show_plots = 0; %1: show plots of training progress
tic
iter = 1;
disp('Performing cross validation evaluation over all network iterations:')
for var1 = {'sgdm','adam'}
params.optimizer = var1;
for var2 = [4,6,8]
params.max_epochs = var2;
%for var3 = X:Y
%params.example = var3;
%etc.
%Splitting training data into k-folds
for k = 1:n_crossval_folds
images_per_fold = floor(n_train_images/n_crossval_folds);
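%The k-th contiguous block of images_per_fold shuffled indices forms the validation fold; all remaining training images form the training fold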
idx_val = find(ismember(img_number, idx(1+(k-1)*images_per_fold:images_per_fold+(k-1)*images_per_fold)));
YValidation = ydata(idx_val);
XValidation = xdata(:,:,:,idx_val);
XTrain = xdata;
XTrain(:,:,:,idx_val) = [];
YTrain = ydata;
YTrain(idx_val) = [];
%ROS input normalization:
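%If enabled, the separately provided ROS.m helper randomly oversamples the training set based on the
%distribution of the response variable, with duplication_factor presumably setting the output size;
%otherwise the training set is simply duplicated duplication_factor times. A hypothetical sketch of
%such a helper (ROS_sketch) is appended at the end of this file.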
if normalize_training_input == 1
[XTrain, YTrain] = ROS(XTrain, YTrain, params.duplication_factor);
else
XTrain = repmat(XTrain,1,1,1,params.duplication_factor);
YTrain = repmat(YTrain,params.duplication_factor,1);
end
%Random geometric image augmentation:
%Augmentation parameters
aug_params.rot = [-90,90]; %Image rotation range
aug_params.trans_x = [-5 5]; %Image translation in X direction range
aug_params.trans_y = [-5 5]; %Image translation in Y direction range
aug_params.refl_x = 1; %Image reflection across X axis
aug_params.refl_y = 1; %Image reflection across Y axis
aug_params.scale = [0.7,1.3]; %Image scaling range
aug_params.shear_x = [-30,50]; %Image shearing in X direction range
aug_params.shear_y = [-30,50]; %Image shearing in Y direction range
aug_params.add_gauss_noise = 0; %Add Gaussian noise
aug_params.gauss_noise_var = 0.0005; %Gaussian noise variance
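%If enabled, the separately provided image_augmentation.m helper applies random geometric operations
%drawn from the ranges above to each training image (a hypothetical sketch, image_augmentation_sketch,
%is appended at the end of this file); if disabled, the aug_params fields are cleared (v1.1 fix)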
if augment_training_data == 1
XTrain = image_augmentation(XTrain,aug_params);
else
aug_params = structfun(@(x) [], aug_params, 'UniformOutput', false);
end
%Network structure:
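%Three convolution + ReLU blocks (conv_features(1:3) filters), with average pooling after the first
%two blocks, followed by dropout and a single-output fully connected layer feeding a regression layer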
layers = [
imageInputLayer([size(XTrain,1),size(XTrain,2),size(XTrain,3)])
convolution2dLayer(params.conv_filter_size,params.conv_features(1),'Padding',params.conv_padding)
%batchNormalizationLayer
reluLayer
averagePooling2dLayer(params.pooling_size,'Stride',params.pooling_stride)
convolution2dLayer(params.conv_filter_size,params.conv_features(2),'Padding',params.conv_padding)
%batchNormalizationLayer
reluLayer
averagePooling2dLayer(params.pooling_size,'Stride',params.pooling_stride)
convolution2dLayer(params.conv_filter_size,params.conv_features(3),'Padding',params.conv_padding)
%batchNormalizationLayer
reluLayer
dropoutLayer(params.dropout_factor)
fullyConnectedLayer(1)
regressionLayer];
params.validationFrequency = floor(numel(YTrain)/params.batch_size);
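%floor(numel(YTrain)/batch_size) is the number of iterations per epoch, so the validation set is evaluated once per epoch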
options = network_options(params,XValidation,YValidation,show_plots);
net = trainNetwork(XTrain,YTrain,layers,options);
%Network results:
accuracy_threshold = 0.1; %Predictions within 10% will be considered 'accurate'
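%Per-fold metrics on both training and validation folds: fraction of predictions within the threshold, mean absolute error, and root-mean-square error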
predicted_train = predict(net,XTrain);
predictionError_train = YTrain - predicted_train;
numCorrect_train = sum(abs(predictionError_train) < accuracy_threshold);
accuracy_train(k) = numCorrect_train/numel(YTrain);
error_abs_train(k) = mean(abs(predictionError_train));
rmse_train(k) = sqrt(mean(predictionError_train.^2));
predicted_val = predict(net,XValidation);
predictionError_val = YValidation - predicted_val;
numCorrect_val = sum(abs(predictionError_val) < accuracy_threshold);
accuracy_val(k) = numCorrect_val/numel(YValidation);
error_abs_val(k) = mean(abs(predictionError_val));
rmse_val(k) = sqrt(mean(predictionError_val.^2));
if k == 1
predicted_val_table(1:numel(YValidation),1) = predict(net,XValidation);
if iter == 1
YValidation_table(1:numel(YValidation),1) = YValidation;
end
else
predicted_val_table(end+1:end+numel(YValidation),1) = predict(net,XValidation);
if iter == 1
YValidation_table(end+1:end+numel(YValidation),1) = YValidation;
end
end
end
%(Per-iteration results table assembly from the full script goes here)
iter = iter + 1;
end
end
%(Final prediction on the independent test set with the best-performing settings follows in the full script and is not reproduced in this excerpt)
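%------------------------------------------------------------------------------------------------
%Illustrative helper sketches (hypothetical, not called by the script):
%ROS.m, image_augmentation.m and network_options.m are provided as separate files and are not
%reproduced here. The *_sketch local functions below are minimal stand-ins, written under stated
%assumptions, showing one way such helpers could be implemented; they may differ from the
%provided versions.

function [Xout, Yout] = ROS_sketch(Xin, Yin, duplication_factor)
%Hypothetical random oversampling: resample images so the response values are drawn roughly
%uniformly across their range, at about duplication_factor times the original dataset size.
n_bins = 10; %assumed number of response bins
edges = linspace(min(Yin), max(Yin), n_bins+1);
bin_idx = discretize(Yin, edges);
n_out = duplication_factor*numel(Yin); %target output size
samples_per_bin = ceil(n_out/n_bins);
keep = [];
for b = 1:n_bins
members = find(bin_idx == b);
if isempty(members), continue; end %skip empty response bins
keep = [keep; members(randi(numel(members), samples_per_bin, 1))]; %#ok<AGROW>
end
Xout = Xin(:,:,:,keep);
Yout = Yin(keep);
end

function Xout = image_augmentation_sketch(Xin, aug_params)
%Hypothetical geometric augmentation: apply one random affine transform per image, drawn from
%the ranges in aug_params, plus optional additive Gaussian noise (assumes floating-point images).
Xout = Xin;
out_view = imref2d([size(Xin,1), size(Xin,2)]);
for n = 1:size(Xin,4)
tform = randomAffine2d( ...
'Rotation',aug_params.rot, ...
'XTranslation',aug_params.trans_x, ...
'YTranslation',aug_params.trans_y, ...
'XReflection',logical(aug_params.refl_x), ...
'YReflection',logical(aug_params.refl_y), ...
'Scale',aug_params.scale, ...
'XShear',aug_params.shear_x, ...
'YShear',aug_params.shear_y);
Xout(:,:,:,n) = imwarp(Xin(:,:,:,n), tform, 'OutputView', out_view);
if aug_params.add_gauss_noise == 1
Xout(:,:,:,n) = Xout(:,:,:,n) + sqrt(aug_params.gauss_noise_var)*randn(size(Xin,1), size(Xin,2), size(Xin,3));
end
end
end

function options = network_options_sketch(params, XValidation, YValidation, show_plots)
%Hypothetical assembly of trainingOptions from the params struct.
if show_plots == 1
plot_setting = 'training-progress';
else
plot_setting = 'none';
end
solver = char(params.optimizer); %params.optimizer may arrive as a 1x1 cell from the gridsearch loop
args = {'MaxEpochs',params.max_epochs, ...
'MiniBatchSize',params.batch_size, ...
'InitialLearnRate',params.learn_rate, ...
'LearnRateSchedule',params.learn_rate_schedule, ...
'LearnRateDropFactor',params.learn_rate_drop_factor, ...
'LearnRateDropPeriod',params.learn_rate_drop_period, ...
'Shuffle',params.shuffle, ...
'L2Regularization',params.L2_reg, ...
'ValidationData',{XValidation,YValidation}, ...
'ValidationFrequency',params.validationFrequency, ...
'Plots',plot_setting, ...
'Verbose',false};
if strcmp(solver,'sgdm')
args = [args, {'Momentum',params.momentum}]; %Momentum is only valid for the sgdm solver
end
options = trainingOptions(solver, args{:});
end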