# #!/usr/bin/env python
# # -*- encoding: utf-8 -*-
# '''
# @文件 :models.py
# @说明 :模型定义文件
# @时间 :2020/02/13 11:42:33
# @作者 :钱彬
# @版本 :1.0
# '''
import torch
from torch import nn
import torchvision
import math
class ConvolutionalBlock(nn.Module):
    """
    Convolutional block: a convolution layer, optionally followed by a
    batch-normalization layer and an activation layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, batch_norm=False, activation=None):
        """
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param kernel_size: convolution kernel size
        :param stride: convolution stride
        :param batch_norm: whether to include a BatchNorm2d layer
        :param activation: activation type ('prelu', 'leakyrelu' or 'tanh',
            case-insensitive); None for no activation
        :raises ValueError: if activation is not one of the supported types
        """
        super(ConvolutionalBlock, self).__init__()

        if activation is not None:
            activation = activation.lower()
            # Raise instead of assert: asserts are stripped under `python -O`,
            # which would silently build a block with no activation.
            if activation not in {'prelu', 'leakyrelu', 'tanh'}:
                raise ValueError(
                    "activation must be one of 'prelu', 'leakyrelu', 'tanh', got %r" % activation)

        # Layer list, assembled in order: conv -> (BN) -> (activation)
        layers = list()

        # Convolution layer; padding = kernel_size // 2 keeps the spatial
        # size unchanged when stride == 1 ('same'-style padding for odd kernels).
        layers.append(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2))

        # Optional batch-normalization layer
        if batch_norm:
            layers.append(nn.BatchNorm2d(num_features=out_channels))

        # Optional activation layer
        if activation == 'prelu':
            layers.append(nn.PReLU())
        elif activation == 'leakyrelu':
            layers.append(nn.LeakyReLU(0.2))
        elif activation == 'tanh':
            layers.append(nn.Tanh())

        # Merge into a single sequential module
        self.conv_block = nn.Sequential(*layers)

    def forward(self, input):
        """
        Forward pass.

        :param input: input image batch, a tensor of size (N, in_channels, w, h)
        :return: output image batch, a tensor of size (N, out_channels, w, h)
            (spatial size divided by `stride` when stride > 1)
        """
        output = self.conv_block(input)

        return output
class SubPixelConvolutionalBlock(nn.Module):
    """
    Sub-pixel convolutional block: convolution, pixel shuffle, then PReLU.
    """

    def __init__(self, kernel_size=3, n_channels=64, scaling_factor=2):
        """
        :param kernel_size: convolution kernel size
        :param n_channels: number of input and output channels
        :param scaling_factor: spatial upscaling factor
        """
        super(SubPixelConvolutionalBlock, self).__init__()

        # Expand the channel count by scaling_factor^2 so that pixel shuffle
        # can trade those extra channels for spatial resolution.
        self.conv = nn.Conv2d(
            in_channels=n_channels,
            out_channels=n_channels * (scaling_factor ** 2),
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        # Rearranges (N, C*r^2, H, W) -> (N, C, H*r, W*r).
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
        # Final activation.
        self.prelu = nn.PReLU()

    def forward(self, input):
        """
        Forward pass.

        :param input: input image batch, a tensor of size (N, n_channels, w, h)
        :return: output image batch, a tensor of size
            (N, n_channels, w * scaling_factor, h * scaling_factor)
        """
        expanded = self.conv(input)                # (N, n_channels * r^2, w, h)
        upscaled = self.pixel_shuffle(expanded)    # (N, n_channels, w * r, h * r)
        activated = self.prelu(upscaled)           # (N, n_channels, w * r, h * r)

        return activated
class ResidualBlock(nn.Module):
    """
    Residual block: two convolutional blocks with a skip connection around them.
    """

    def __init__(self, kernel_size=3, n_channels=64):
        """
        :param kernel_size: kernel size
        :param n_channels: number of input and output channels (kept equal so
            the skip connection can be added element-wise)
        """
        super(ResidualBlock, self).__init__()

        # First conv block: conv + BN + PReLU.
        self.conv_block1 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                              batch_norm=True, activation='PReLu')

        # Second conv block: conv + BN, no activation.
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                              batch_norm=True, activation=None)

    def forward(self, input):
        """
        Forward pass.

        :param input: input image batch, a tensor of size (N, n_channels, w, h)
        :return: output image batch, a tensor of size (N, n_channels, w, h)
        """
        skip = input                         # (N, n_channels, w, h)
        out = self.conv_block1(input)        # (N, n_channels, w, h)
        out = self.conv_block2(out)          # (N, n_channels, w, h)

        # Add the skip connection.
        return out + skip                    # (N, n_channels, w, h)
class SRResNet(nn.Module):
    """
    SRResNet super-resolution model.
    """

    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        """
        :param large_kernel_size: kernel size of the first and last conv layers
        :param small_kernel_size: kernel size of the intermediate conv layers
        :param n_channels: number of intermediate channels
        :param n_blocks: number of residual blocks
        :param scaling_factor: upscaling factor (2, 4 or 8)
        """
        super(SRResNet, self).__init__()

        # The upscaling factor must be 2, 4 or 8 so it can be decomposed
        # into x2 sub-pixel stages.
        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}, "放大比例必须为 2、 4 或 8!"

        # Head: large-kernel conv block.
        self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='PReLu')

        # Trunk: a stack of residual blocks, each with its own skip connection.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for _ in range(n_blocks)])

        # Post-trunk conv block (its output is added to the head output below).
        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=small_kernel_size,
                                              batch_norm=True, activation=None)

        # Upsampling: log2(scaling_factor) sub-pixel stages, each upscaling x2.
        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=n_channels, scaling_factor=2)
              for _ in range(n_subpixel_convolution_blocks)])

        # Tail: large-kernel conv block mapping back to 3 channels, Tanh output.
        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='Tanh')

    def forward(self, lr_imgs):
        """
        Forward pass.

        :param lr_imgs: low-resolution input image batch, a tensor of size (N, 3, w, h)
        :return: high-resolution output image batch, a tensor of size
            (N, 3, w * scaling_factor, h * scaling_factor)
        """
        out = self.conv_block1(lr_imgs)                  # (N, n_channels, w, h)
        skip = out                                       # long skip over the whole trunk
        out = self.residual_blocks(out)                  # (N, n_channels, w, h)
        out = self.conv_block2(out)                      # (N, n_channels, w, h)
        out = out + skip                                 # (N, n_channels, w, h)
        out = self.subpixel_convolutional_blocks(out)    # (N, n_channels, w * sf, h * sf)
        sr_imgs = self.conv_block3(out)                  # (N, 3, w * sf, h * sf)

        return sr_imgs
class Generator(nn.Module):
"""
生成器模型,其结构与SRResNet完全一致.
"""
def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
"""
参数 large_kernel_size:第一层和最后一层卷积核
没有合适的资源?快使用搜索试试~ 我知道了~
资源详情
资源评论
资源推荐
收起资源包目录
超分辨率重建SRGAN代码(包括将SRGAN生成器单独拿出来做残差网络) (523个子文件)
events.out.tfevents.1643937347.DESKTOP-OL4H2J3.5988.0 20.1MB
events.out.tfevents.1642579617.DESKTOP-OL4H2J3.14976.0 3.2MB
events.out.tfevents.1642516648.DESKTOP-OL4H2J3.10760.0 2.51MB
events.out.tfevents.1642564335.DESKTOP-OL4H2J3.22836.0 2.31MB
events.out.tfevents.1642494171.DESKTOP-OL4H2J3.2068.0 131KB
events.out.tfevents.1642491317.DESKTOP-OL4H2J3.12688.0 119KB
events.out.tfevents.1644001953.DESKTOP-OL4H2J3.14204.0 40B
events.out.tfevents.1643999850.DESKTOP-OL4H2J3.9172.0 40B
events.out.tfevents.1643974578.DESKTOP-OL4H2J3.16836.0 40B
events.out.tfevents.1644055336.DESKTOP-OL4H2J3.14048.0 40B
events.out.tfevents.1643968961.DESKTOP-OL4H2J3.14876.0 40B
events.out.tfevents.1643999153.DESKTOP-OL4H2J3.8504.0 40B
events.out.tfevents.1642566507.DESKTOP-OL4H2J3.8896.0 40B
events.out.tfevents.1643965452.DESKTOP-OL4H2J3.10364.0 40B
events.out.tfevents.1643941596.DESKTOP-OL4H2J3.6848.0 40B
events.out.tfevents.1643980201.DESKTOP-OL4H2J3.4424.0 40B
events.out.tfevents.1644044096.DESKTOP-OL4H2J3.2748.0 40B
events.out.tfevents.1643949313.DESKTOP-OL4H2J3.2496.0 40B
events.out.tfevents.1643978096.DESKTOP-OL4H2J3.9172.0 40B
events.out.tfevents.1644008971.DESKTOP-OL4H2J3.848.0 40B
events.out.tfevents.1643961939.DESKTOP-OL4H2J3.8024.0 40B
events.out.tfevents.1643957028.DESKTOP-OL4H2J3.14828.0 40B
events.out.tfevents.1644050421.DESKTOP-OL4H2J3.13280.0 40B
events.out.tfevents.1643985809.DESKTOP-OL4H2J3.9576.0 40B
events.out.tfevents.1643997747.DESKTOP-OL4H2J3.17948.0 40B
events.out.tfevents.1644020210.DESKTOP-OL4H2J3.13512.0 40B
events.out.tfevents.1643993536.DESKTOP-OL4H2J3.6360.0 40B
events.out.tfevents.1644021610.DESKTOP-OL4H2J3.7664.0 40B
events.out.tfevents.1642578920.DESKTOP-OL4H2J3.8784.0 40B
events.out.tfevents.1644048314.DESKTOP-OL4H2J3.9516.0 40B
events.out.tfevents.1643981605.DESKTOP-OL4H2J3.18340.0 40B
events.out.tfevents.1644010379.DESKTOP-OL4H2J3.2740.0 40B
events.out.tfevents.1643987919.DESKTOP-OL4H2J3.12024.0 40B
events.out.tfevents.1644030747.DESKTOP-OL4H2J3.15576.0 40B
events.out.tfevents.1643956323.DESKTOP-OL4H2J3.5748.0 40B
events.out.tfevents.1644032850.DESKTOP-OL4H2J3.10480.0 40B
events.out.tfevents.1644011081.DESKTOP-OL4H2J3.7224.0 40B
events.out.tfevents.1643962642.DESKTOP-OL4H2J3.6824.0 40B
events.out.tfevents.1642572264.DESKTOP-OL4H2J3.21056.0 40B
events.out.tfevents.1644013892.DESKTOP-OL4H2J3.4020.0 40B
events.out.tfevents.1644030045.DESKTOP-OL4H2J3.16764.0 40B
events.out.tfevents.1643983009.DESKTOP-OL4H2J3.2176.0 40B
events.out.tfevents.1644055340.DESKTOP-OL4H2J3.1788.0 40B
events.out.tfevents.1643961240.DESKTOP-OL4H2J3.9328.0 40B
events.out.tfevents.1644037064.DESKTOP-OL4H2J3.564.0 40B
events.out.tfevents.1643943698.DESKTOP-OL4H2J3.18192.0 40B
events.out.tfevents.1643973879.DESKTOP-OL4H2J3.17524.0 40B
events.out.tfevents.1643971769.DESKTOP-OL4H2J3.9132.0 40B
events.out.tfevents.1644058148.DESKTOP-OL4H2J3.6300.0 40B
events.out.tfevents.1644034255.DESKTOP-OL4H2J3.18080.0 40B
events.out.tfevents.1643952118.DESKTOP-OL4H2J3.13012.0 40B
events.out.tfevents.1643961943.DESKTOP-OL4H2J3.11992.0 40B
events.out.tfevents.1643963348.DESKTOP-OL4H2J3.7932.0 40B
events.out.tfevents.1643979496.DESKTOP-OL4H2J3.11764.0 40B
events.out.tfevents.1644004060.DESKTOP-OL4H2J3.11568.0 40B
events.out.tfevents.1643952121.DESKTOP-OL4H2J3.14516.0 40B
events.out.tfevents.1643951420.DESKTOP-OL4H2J3.8076.0 40B
events.out.tfevents.1644040579.DESKTOP-OL4H2J3.9280.0 40B
events.out.tfevents.1643975989.DESKTOP-OL4H2J3.5960.0 40B
events.out.tfevents.1643979499.DESKTOP-OL4H2J3.5612.0 40B
events.out.tfevents.1643978092.DESKTOP-OL4H2J3.10588.0 40B
events.out.tfevents.1643939490.DESKTOP-OL4H2J3.6072.0 40B
events.out.tfevents.1644056740.DESKTOP-OL4H2J3.16532.0 40B
events.out.tfevents.1643987212.DESKTOP-OL4H2J3.18368.0 40B
events.out.tfevents.1643959831.DESKTOP-OL4H2J3.14924.0 40B
events.out.tfevents.1643938078.DESKTOP-OL4H2J3.9840.0 40B
events.out.tfevents.1644039178.DESKTOP-OL4H2J3.11384.0 40B
events.out.tfevents.1643949309.DESKTOP-OL4H2J3.17008.0 40B
events.out.tfevents.1642566502.DESKTOP-OL4H2J3.17520.0 40B
events.out.tfevents.1643986514.DESKTOP-OL4H2J3.17980.0 40B
events.out.tfevents.1644007571.DESKTOP-OL4H2J3.2052.0 40B
events.out.tfevents.1643999149.DESKTOP-OL4H2J3.13884.0 40B
events.out.tfevents.1644026533.DESKTOP-OL4H2J3.14676.0 40B
events.out.tfevents.1642578192.DESKTOP-OL4H2J3.9072.0 40B
events.out.tfevents.1643951416.DESKTOP-OL4H2J3.17260.0 40B
events.out.tfevents.1644001251.DESKTOP-OL4H2J3.16960.0 40B
events.out.tfevents.1643955625.DESKTOP-OL4H2J3.14648.0 40B
events.out.tfevents.1643968263.DESKTOP-OL4H2J3.16120.0 40B
events.out.tfevents.1642565789.DESKTOP-OL4H2J3.1540.0 40B
events.out.tfevents.1642576752.DESKTOP-OL4H2J3.6384.0 40B
events.out.tfevents.1643998446.DESKTOP-OL4H2J3.11924.0 40B
events.out.tfevents.1643973876.DESKTOP-OL4H2J3.4032.0 40B
events.out.tfevents.1643975985.DESKTOP-OL4H2J3.10680.0 40B
events.out.tfevents.1643990724.DESKTOP-OL4H2J3.5104.0 40B
events.out.tfevents.1644036361.DESKTOP-OL4H2J3.15036.0 40B
events.out.tfevents.1644015997.DESKTOP-OL4H2J3.1032.0 40B
events.out.tfevents.1643991430.DESKTOP-OL4H2J3.4884.0 40B
events.out.tfevents.1643993533.DESKTOP-OL4H2J3.11696.0 40B
events.out.tfevents.1643977390.DESKTOP-OL4H2J3.8716.0 40B
events.out.tfevents.1642570834.DESKTOP-OL4H2J3.18960.0 40B
events.out.tfevents.1644000550.DESKTOP-OL4H2J3.15104.0 40B
events.out.tfevents.1643983706.DESKTOP-OL4H2J3.6412.0 40B
events.out.tfevents.1643964047.DESKTOP-OL4H2J3.8248.0 40B
events.out.tfevents.1642564346.DESKTOP-OL4H2J3.15876.0 40B
events.out.tfevents.1643960537.DESKTOP-OL4H2J3.10636.0 40B
events.out.tfevents.1643944402.DESKTOP-OL4H2J3.6008.0 40B
events.out.tfevents.1643988618.DESKTOP-OL4H2J3.17916.0 40B
events.out.tfevents.1644039876.DESKTOP-OL4H2J3.15944.0 40B
events.out.tfevents.1644009677.DESKTOP-OL4H2J3.17424.0 40B
events.out.tfevents.1644045498.DESKTOP-OL4H2J3.13280.0 40B
共 523 条
- 1
- 2
- 3
- 4
- 5
- 6
Kidolle
- 粉丝: 14
- 资源: 2
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
评论1