import torch
import torch.nn as nn
# residual block
class BasicBlock(nn.Module):
    """Two-layer residual block: two 3x3 conv + batch-norm pairs plus a shortcut.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by both convolutions.
        stride: Stride of the first convolution; the second always uses 1.
        downsample: Optional module applied to the input to build the
            shortcut. When ``None`` the input itself is used as the shortcut.
    """

    # Output channels of the block are out_channel * expansion.
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super().__init__()
        # First 3x3 conv; it is the only one that receives ``stride``, so any
        # spatial reduction of the block happens here.
        self.conv1 = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        # Second 3x3 conv keeps both the channel count and the spatial size.
        self.conv2 = nn.Conv2d(
            out_channel,
            out_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        # Shortcut branch: the raw input, or its projection when a
        # ``downsample`` module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
# bottleneck
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Width of the two inner convolutions. The block emits
            ``out_channel * expansion`` channels.
        stride: Stride of the 3x3 convolution (the only strided layer).
        downsample: Optional module applied to the input to build the
            shortcut. When ``None`` the input itself is used, so it must
            already have ``out_channel * expansion`` channels and the block
            must not change the spatial size.
    """

    # Ratio between the block's output channels and ``out_channel``.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv: reduces the channel count; padding defaults to 0, so the
        # spatial size is unchanged.
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3x3 conv: padding=1 is required so that the spatial size is kept at
        # stride 1 (and matches a stride-2 1x1 shortcut at stride 2); without
        # it the main path comes out smaller than the shortcut and the
        # residual addition in ``forward`` fails with a shape mismatch.
        self.conv2 = nn.Conv2d(
            out_channel,
            out_channel,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1x1 conv: expands the channel count to out_channel * expansion;
        # padding defaults to 0, so the spatial size is unchanged.
        self.conv3 = nn.Conv2d(
            out_channel,
            out_channel * self.expansion,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        # Shared activation and optional shortcut projection.
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        identity = x
        if self.downsample is not None:
            # Project the input so it matches the main path's output shape.
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)
        return out
# resnet
class ResNet(nn.Module):
    """ResNet backbone: conv stem, four residual stages, optional classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_channel, out_channel, stride=...,
            downsample=...)``.
        block_num: Sequence giving the number of blocks in each of the four
            stages (indices 0..3 are read).
        num_classes: Output size of the final linear layer.
        include_top: When true, global average pooling and the linear
            classifier are created and applied; otherwise ``forward`` returns
            the feature map of the last stage.
    """

    def __init__(self, block, block_num, num_classes=1000, include_top=True):
        super().__init__()
        self.include_top = include_top
        # Channel count entering the first stage (i.e. after the max pool).
        self.in_channel = 64

        # Stem: 3-channel input -> 64 channels; the stride-2 conv and the
        # stride-2 max pool each halve the spatial size.
        self.conv1 = nn.Conv2d(
            3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (width, stride) of conv2_x .. conv5_x; only the first stage keeps
        # the spatial size.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stride) in enumerate(stage_specs):
            stage = self.__make_layer(block, width, block_num[idx], stride=stride)
            setattr(self, "layer%d" % (idx + 1), stage)

        # Classification head.
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size is 1x1
            self.fc = nn.Linear(512 * block.expansion, num_classes)

    def __make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks and update ``self.in_channel``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1 conv + batch-norm projection as its
        shortcut. The remaining blocks use the defaults.
        """
        out_width = channel * block.expansion

        shortcut = None
        if not (stride == 1 and self.in_channel == out_width):
            # Projection shortcut: 1x1 conv matching the stage's output
            # channels and stride.
            shortcut = nn.Sequential(
                nn.Conv2d(
                    self.in_channel,
                    out_width,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_width),
            )

        stage = [block(self.in_channel, channel, downsample=shortcut, stride=stride)]
        self.in_channel = out_width
        # Remaining blocks keep shape, so they use identity shortcuts.
        stage.extend(block(self.in_channel, channel) for _ in range(1, block_num))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem and the four stages, then the head if ``include_top``."""
        # Convolutional stem.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Residual feature-extraction stages.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        if not self.include_top:
            return x

        # Classification head.
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)
# Network factory functions
def resnet34(num_classes=1000, include_top=True):
    """Build a ``ResNet`` from ``BasicBlock`` with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        BasicBlock,
        stage_depths,
        num_classes=num_classes,
        include_top=include_top,
    )
def resnet101(num_classes=1000, include_top=True):
    """Build a ``ResNet`` from ``Bottleneck`` with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(
        Bottleneck,
        stage_depths,
        num_classes=num_classes,
        include_top=include_top,
    )
没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
收起资源包目录
demo.zip (15个子文件)
demo
resnet_pre_model.pth 83.27MB
data
cifar-10-batches-py
readme.html 88B
data_batch_2 29.6MB
data_batch_3 29.6MB
data_batch_4 29.6MB
data_batch_5 29.6MB
data_batch_1 29.6MB
batches.meta 158B
test_batch 29.6MB
cifar-10-python.tar.gz 162.6MB
resnet.pth 81.34MB
predict.py 2KB
model.py 5KB
train.py 3KB
__pycache__
model.cpython-310.pyc 4KB
共 15 条
- 1
资源评论
听风吹等浪起
- 粉丝: 1w+
- 资源: 1178
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功