clear
clc
% Transfer learning: fine-tune a pretrained AlexNet on a folder of labelled
% images and time the training run.

% --- Basic settings ---
foldername = "numberdata"; % dataset folder; each subfolder name is used as a class label

% --- Data ---
% NOTE: the datastore is named imds (not "image") and the class count
% numClasses (not "class") so the built-in functions image() and class()
% are not shadowed in the workspace after this script runs.
imds = imageDatastore(foldername,'IncludeSubfolders',true,'LabelSource','foldernames');
% 70% of each label goes to training, the remainder is used for validation.
[imageTrain,imageTest] = splitEachLabel(imds,0.7,'randomized');
% Derive the number of classes from the labels so the new output layer
% always matches the dataset.
numClasses = numel(categories(imds.Labels));

% --- Network ---
net = alexnet;
inputSize = net.Layers(1).InputSize; % image size expected by the input layer
% Drop the last three layers of the pretrained network and append new
% classification layers sized for this dataset.
layers = net.Layers(1:end-3);
new_layers = [layers
    fullyConnectedLayer(numClasses,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
    softmaxLayer
    classificationLayer
    ];

% Resize every image to the network input size (and expand grayscale images
% to three channels) on the fly; training fails on mismatched image sizes.
augTrain = augmentedImageDatastore(inputSize(1:2),imageTrain,'ColorPreprocessing','gray2rgb');
augTest = augmentedImageDatastore(inputSize(1:2),imageTest,'ColorPreprocessing','gray2rgb');

% --- Training options ---
ops = trainingOptions('sgdm', ...
    'InitialLearnRate',0.0001, ...
    'ValidationData',augTest, ...
    'Plots','training-progress', ...
    'MiniBatchSize',4, ...
    'MaxEpochs',5, ... % number of training epochs
    'ValidationPatience',Inf, ...
    'Verbose',false);

% --- Train and report elapsed time ---
tic
net_train = trainNetwork(augTrain,new_layers,ops);
toc