pytorch VGG11识别识别cifar10数据集数据集(训练训练+预测单张输入图片操预测单张输入图片操
作作)
主要介绍了pytorch VGG11识别cifar10数据集(训练+预测单张输入图片操作),具有很好的参考价值,希望对大
家有所帮助。一起跟随小编过来看看吧
首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷
积层:
一,写VGG代码时,首先定义一个 vgg_block(n,in,out)方法,用来构建VGG中每个block中的卷积核和池化层:
n是这个block中卷积层的数目,in是输入的通道数,out是输出的通道数
有了block以后,我们还需要一个方法把形成的block叠在一起,我们定义这个方法叫vgg_stack:
def vgg_stack(num_convs, channels): # vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
net = []
for n, c in zip(num_convs, channels):
in_c = c[0]
out_c = c[1]
net.append(vgg_block(n, in_c, out_c))
return nn.Sequential(*net)
右边的注释
vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))
里,(1, 1, 2, 2, 2)表示五个block里,各自的卷积层数目,((3, 64), (64, 128), (128, 256), (256, 512), (512, 512))表示每个block
中的卷积层的类型,如(3,64)表示这个卷积层输入通道数是3,输出通道数是64。vgg_stack方法返回的就是完整的vgg11模型
评论0
最新资源