import torch
import torch.nn as nn
# residual block
class BasicBlock(nn.Module):
    """Basic residual block used by ResNet-18/34: two 3x3 convs plus a shortcut.

    When the block opens a stage (stride=2) the first conv halves the spatial
    size and `downsample` (a 1x1 conv + BN, the "dashed" shortcut) is supplied
    by the caller to bring the identity branch to the matching shape.
    """
    # Output channels of the block = out_channel * expansion (1 for BasicBlock).
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First 3x3 conv; stride=2 here halves H/W at the start of a stage.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        # Second 3x3 conv always keeps the spatial size (stride=1, padding=1).
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        # Shortcut branch: projected when shape changes, identity otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Residual add, then the final activation.
        return self.relu(y + shortcut)
# bottleneck
class Bottleneck(nn.Module):
    """Bottleneck residual block used by ResNet-50/101/152.

    Structure: 1x1 conv (reduce channels) -> 3x3 conv (possibly strided)
    -> 1x1 conv (expand channels by `expansion`), with a shortcut added
    before the final ReLU. `downsample` is the 1x1 projection shortcut
    supplied by the caller when shape or stride changes.
    """
    # Output channels of the block = out_channel * expansion.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces the channel count; spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1,
                               stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3x3 conv. FIX: padding=1 was missing, so every bottleneck shrank
        # the feature map by 2 pixels and the residual add `out += identity`
        # failed with a shape mismatch. padding=1 keeps H/W (halved only
        # when stride=2 at the start of a stage), matching standard ResNet.
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1x1 conv restores the channel count to out_channel * expansion.
        self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            # Projection shortcut: match channels/stride of the main branch.
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # Residual add, then the final activation.
        out += identity
        out = self.relu(out)
        return out
# resnet
class ResNet(nn.Module):
    """Generic ResNet backbone assembled from a residual block class.

    Args:
        block: block class (e.g. a basic or bottleneck block) exposing a
            class attribute `expansion` and the constructor signature
            (in_channel, out_channel, stride=..., downsample=...).
        block_num: list of four ints — number of blocks in stages
            conv2_x .. conv5_x.
        num_classes: size of the final linear layer (only if include_top).
        include_top: when True, append global average pooling + classifier.
    """

    def __init__(self, block, block_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64  # channel count after the stem max-pool

        # Stem: 7x7 stride-2 conv on RGB input halves H/W and yields 64
        # channels; the 3x3 stride-2 max-pool halves H/W again.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; stages 2-4 halve the spatial size on entry.
        self.layer1 = self._make_layer(block, 64, block_num[0])             # conv2_x
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)  # conv5_x

        if self.include_top:
            # Classifier head: pool to 1x1 regardless of input size, then FC.
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one residual stage of `block_num` blocks.

        The first block receives the stage stride and, when the shape
        changes, a 1x1 projection shortcut (dashed line in the paper);
        the remaining blocks use plain identity shortcuts.
        """
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion),
            )
        stage = [block(self.in_channel, channel,
                       downsample=downsample, stride=stride)]
        self.in_channel = channel * block.expansion
        stage.extend(block(self.in_channel, channel)
                     for _ in range(1, block_num))
        return nn.Sequential(*stage)

    def forward(self, x):
        # Stem.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # Residual feature-extraction stages.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Optional classification head.
        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, start_dim=1)
            x = self.fc(x)
        return x
# 定义网络
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: BasicBlock with stage depths [3, 4, 6, 3]."""
    return ResNet(BasicBlock, [3, 4, 6, 3],
                  num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: Bottleneck with stage depths [3, 4, 23, 3]."""
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes, include_top=include_top)