没有合适的资源?快使用搜索试试~ 我知道了~
资源推荐
资源详情
资源评论
pytorch中获取模型中获取模型input/output shape实例实例
今天小编就为大家分享一篇pytorch中获取模型input/output shape实例,具有很好的参考价值,希望对大家有所
帮助。一起跟随小编过来看看吧
Pytorch官方目前无法像tensorflow, caffe那样直接给出shape信息,详见
https://github.com/pytorch/pytorch/pull/3043
以下代码算一种workaround。由于CNN, RNN等模块实现不一样,添加其他模块支持可能需要改代码。
例如RNN中bias是bool类型,其权重也不是存于weight属性中,不过我们只关注shape够用了。
该方法必须构造一个输入调用forward后(model(x)调用)才可获取shape
#coding:utf-8
from collections import OrderedDict
import torch
from torch.autograd import Variable
import torch.nn as nn
import models.crnn as crnn
import json
def get_output_size(summary_dict, output):
if isinstance(output, tuple):
for i in xrange(len(output)):
summary_dict[i] = OrderedDict()
summary_dict[i] = get_output_size(summary_dict[i],output[i])
else:
summary_dict['output_shape'] = list(output.size())
return summary_dict
def summary(input_size, model):
def register_hook(module):
def hook(module, input, output):
class_name = str(module.__class__).split('.')[-1].split("'")[0]
module_idx = len(summary)
m_key = '%s-%i' % (class_name, module_idx+1)
summary[m_key] = OrderedDict()
summary[m_key]['input_shape'] = list(input[0].size())
summary[m_key] = get_output_size(summary[m_key], output)
params = 0
if hasattr(module, 'weight'):
params += torch.prod(torch.LongTensor(list(module.weight.size())))
if module.weight.requires_grad:
summary[m_key]['trainable'] = True
else:
summary[m_key]['trainable'] = False
#if hasattr(module, 'bias'):
# params += torch.prod(torch.LongTensor(list(module.bias.size())))
summary[m_key]['nb_params'] = params
if not isinstance(module, nn.Sequential) and \
not isinstance(module, nn.ModuleList) and \
not (module == model):
hooks.append(module.register_forward_hook(hook))
# check if there are multiple inputs to the network
if isinstance(input_size[0], (list, tuple)):
x = [Variable(torch.rand(1,*in_size)) for in_size in input_size]
else:
x = Variable(torch.rand(1,*input_size))
# create properties
summary = OrderedDict()
hooks = []
# register hook
model.apply(register_hook)
# make a forward pass
model(x)
# remove these hooks
for h in hooks:
h.remove()
weixin_38544152
- 粉丝: 4
- 资源: 924
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页