import torch.nn as nn # 用来实现新的子网络模型
from torchvision.models import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
from torch import load # 记载vgg已训练模型
from torch.nn import ModuleDict # 用来实现新的子网络模型
from torch.nn.functional import relu # 激活函数
from torchvision.models.vgg import cfgs, make_layers # VGG网络配置与层构建工具
class NEW_VGG19(nn.ModuleDict):
# 在构造器初始化新的网络结构层
def __init__(self, is_avg_pool=True):
super().__init__()
# 池化层类型参数
self.is_avg_pool = is_avg_pool
# 使用数组表示卷积层名字,这样便于使用名字直接指定输出层
self.layer_names = [
"conv1_1", "conv1_2", "pool1",
"conv2_1", "conv2_2", "pool2",
"conv3_1", "conv3_2", "conv3_3", "conv3_4", "pool3",
"conv4_1", "conv4_2", "conv4_3", "conv4_4", "pool4",
"conv5_1", "conv5_2", "conv5_3", "conv5_4", "pool5"
]
# 使用make_layers函数构造VGG的卷积层
layers = make_layers(cfgs["E"])
# make_layers(cfgs["E"])生成的网络中,把过滤掉激活层
layers = filter(lambda m: not isinstance(m, nn.ReLU), layers)
# 根据参数is_avg_pool,把layers中的最大池化运算替换成均值池化。
layers = map(lambda m: nn.AvgPool2d(2, 2) if (isinstance(m, nn.MaxPool2d) and self.is_avg_pool) else m, layers)
# 把过滤,替换后的vgg卷积层与每次的名字生成字典,并更新到ModuleDict结构中
layers = dict(zip(self.layer_names, layers))
for name, layer in layers.items():
self[name] = layer
# self.update(dict(zip(self.layer_names, layers))) # update有排序功能
# self.update(dd)
# 关闭自动求导
for p in self.parameters():
p.requires_grad_(False)
# 加载原生VGG的训练参数模型(需要原生参数模型映射到新的模型的对应层)
def load_state_dict(self, state_dict, **kwargs):
# 原来存储结构中的层数,注意其中只包含卷积层与池化层(与layer_names对应)
original_names = ["0", "2", "4", "5", "7", "9", "10", "12", "14", "16", "18", "19", "21", "23", "25", "27", "28", "30", "32", "34", "36"]
# 把新的层名与原来层的序号对应起来,这样便于赋值正确
new_mapping = dict(zip(original_names, self.layer_names))
# 生成新的卷积层与池化层的字典(用新的名字替换旧的名字)
new_state_dict = state_dict.copy()
# 循环替换
for k in state_dict.keys():
# 把new_state_dict中分类层删除
if "classifier" in k:
# 删除分类层
del new_state_dict[k]
# 继续下一次循环
continue
# 取卷积层或者池化层索引id(0-36,一共37层)
idx = k.split(".")[1]
# 根据id获取新的名字(替换k中features.idx,保留weights与bias后缀)
name = k.replace("features." + idx, new_mapping[idx])
# 增加一个新的key与值
new_state_dict[name] = state_dict[k]
# 删除原来的key与值
del new_state_dict[k]
# 加载新的存储结构到模型
super().load_state_dict(new_state_dict, **kwargs)
# 定制输出多层结果
def forward(self, x, layers=None):
# 如果没有layers参数,就默认输出所有层
layers = layers or self.keys()
# 输入数据
outputs = {"input": x} # 输入数据永远放在字典的最后一个。
for name, layer in self.items():
inp = outputs[[*outputs.keys()][-1]] # 取出输入数据
out = relu(layer(inp)) if "pool" not in name else layer(inp) # 计算输出,pool层不做激活函数运算
outputs.update({name: out}) # 添加输入数据到字典
del outputs[[*outputs.keys()][-2]] # 删除本次输入数据,本次输出作为下一次输入。
if name in layers:
yield outputs[name]
# 测试代码
if __name__ == "__main__":
new_vgg19_net = NEW_VGG19(is_avg_pool=True)
state_dict =load("../data/vgg19-dcbb9e9d.pth")
new_vgg19_net.load_state_dict(state_dict)
print("----------")
print(new_vgg19_net)
没有合适的资源?快使用搜索试试~ 我知道了~
资源详情
资源评论
资源推荐
收起资源包目录
![package](https://csdnimg.cn/release/downloadcmsfe/public/img/package.f3fc750b.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![folder](https://csdnimg.cn/release/downloadcmsfe/public/img/folder.005fa2e5.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
![file-type](https://csdnimg.cn/release/download/static_files/pc/images/minetype/UNKNOWN.png)
共 8 条
- 1
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![avatar](https://profile-avatar.csdnimg.cn/955e2f54fd1049edb763ecb6b8309a38_a1234556667.jpg!1)
撸码的xiao摩羯
- 粉丝: 183
- 资源: 92
上传资源 快速赚钱
我的内容管理 展开
我的资源 快来上传第一个资源
我的收益
登录查看自己的收益我的积分 登录查看自己的积分
我的C币 登录后查看C币余额
我的收藏
我的下载
下载帮助
![voice](https://csdnimg.cn/release/downloadcmsfe/public/img/voice.245cc511.png)
![center-task](https://csdnimg.cn/release/downloadcmsfe/public/img/center-task.c2eda91a.png)
安全验证
文档复制为VIP权益,开通VIP直接复制
![dialog-icon](https://csdnimg.cn/release/downloadcmsfe/public/img/green-success.6a4acb44.png)
评论0