import itertools
import os
import random
import argparse
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
print(" -- 使用GPU进行训练 -- ")
def parseArgs():
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument("--save_path", type=str, default="./result/")
parser.add_argument("--data_path", type=str, default="../Data/")
parser.add_argument("--origin", type=str, default="GT")
parser.add_argument("--hazy", type=str, default="hazy")
parser.add_argument("--batch_size", type=int, default=4)
args = parser.parse_args()
return args
## 生成器 U-Net(输入照片为256*256) ##
class Generator(nn.Module):
def __init__(self, in_ch, out_ch, ngf=64):
"""
定义生成器的网络结构
:param in_ch: 输入数据的通道数
:param out_ch: 输出数据的通道数
:param ngf: 第一层卷积的通道数 number of generator's first conv filters
"""
super(Generator, self).__init__()
# 下面的激活函数都放在下一个模块的第一步 是为了skip-connect方便
# 左半部分 U-Net encoder
# 每层输入大小折半,从输入图片大小256开始
# 256 * 256(输入)
self.en1 = nn.Sequential(
nn.Conv2d(in_ch, ngf, kernel_size=4, stride=2, padding=1),
# 输入图片已正则化 不需BatchNorm
)
# 128 * 128
self.en2 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 2)
)
# 64 * 64
self.en3 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 4)
)
# 32 * 32
self.en4 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 16 * 16
self.en5 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 8 * 8
self.en6 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 4 * 4
self.en7 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 2 * 2
self.en8 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1)
# Encoder输出不用BatchNorm
)
# 右半部分 U-Net decoder
# skip-connect: 前一层的输出+对称的卷积层
# 1 * 1(输入)
self.de1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 2 * 2
self.de2 = nn.Sequential(
nn.ReLU(inplace=True),
# skip-connect 所以输入管道数是之前输出的2倍
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 4 * 4
self.de3 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 8 * 8
self.de4 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 16 * 16
self.de5 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 4),
nn.Dropout(p=0.5)
)
# 32 * 32
self.de6 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 2),
nn.Dropout(p=0.5)
)
# 64 * 64
self.de7 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf),
nn.Dropout(p=0.5)
)
# 128 * 128
self.de8 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 2, out_ch, kernel_size=4, stride=2, padding=1),
# Encoder输出不用BatchNorm
nn.Tanh()
)
def forward(self, X):
"""
生成器模块前向传播
:param X: 输入生成器的数据
:return: 生成器的输出
"""
# Encoder
en1_out = self.en1(X)
en2_out = self.en2(en1_out)
en3_out = self.en3(en2_out)
en4_out = self.en4(en3_out)
en5_out = self.en5(en4_out)
en6_out = self.en6(en5_out)
en7_out = self.en7(en6_out)
en8_out = self.en8(en7_out)
# Decoder
de1_out = self.de1(en8_out)
de1_cat = torch.cat([de1_out, en7_out], dim=1) # cat by channel
de2_out = self.de2(de1_cat)
de2_cat = torch.cat([de2_out, en6_out], 1)
de3_out = self.de3(de2_cat)
de3_cat = torch.cat([de3_out, en5_out], 1)
de4_out = self.de4(de3_cat)
de4_cat = torch.cat([de4_out, en4_out], 1)
de5_out = self.de5(de4_cat)
de5_cat = torch.cat([de5_out, en3_out], 1)
de6_out = self.de6(de5_cat)
de6_cat = torch.cat([de6_out, en2_out], 1)
de7_out = self.de7(de6_cat)
de7_cat = torch.cat([de7_out, en1_out], 1)
de8_out = self.de8(de7_cat)
return de8_out
## 辨别器 PatchGAN(其实就是卷积网络而已) ##
class Discriminator(nn.Module):
def __init__(self, in_ch, ndf=64):
"""
定义判别器的网络结构
:param in_ch: 输入数据的通道数
:param ndf: 第一层卷积的通道数 number of discriminator's first conv filters
"""
super(Discriminator, self).__init__()
# 不是输出一个表示真假概率的实数,而是一个N*N的Patch矩阵(此处为30*30),其中每一块对应输入数据的一小块
# in_ch + out_ch 是为将对应真假数据同时输入
# 256 * 256(输入)
self.layer1 = nn.Sequential(
nn.Conv2d(in_ch, ndf, kernel_size=4, stride=2, padding=1),
# 输入图片已正则化 不需BatchNorm
nn.LeakyReLU(0.2, inplace=True)
)
# 128 * 128
self.layer2 = nn.Sequential(
nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True)
)
# 64 * 64
self.layer3 = nn.Sequential(
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True)
)
# 32 * 32
self.layer4 = nn.Sequential(
nn.Conv2d(ndf * 4, ndf * 8
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于Pytorch实现对偶生成对抗网络来实现图像去雾python源码+项目说明+代码注释.zip 【资源介绍】 DualGan含有两个生成器和辨别器 本项目中为同样结构,生成器为U-Net,辨别器为PatchGan的辨别器 G_A:有雾生成无雾 G_B: 无雾生成有雾 D_A: 辨别G_B生成的有雾图像,输入为6通道 D_B: 辨别G_A生成的无雾图像,输入为6通道 train.py用来训练网络 predict.py用来预测无雾图像 预训练好的模型在model下 训练方法 将成对的图片分别放在clear和hazy下,并放在data_path下 然后运行train.py并输入需要的参数即可 保存路径--save_path, default="./save/" 训练数据路径--data_path, default="../Data/" 训练数据清晰图像路径--clear, default="clear" 训练数据有雾图像路径--hazy", default="hazy" --image_size, default=256 --batch_size, default=4 更多请见项目说明!
资源推荐
资源详情
资源评论
收起资源包目录
基于Pytorch实现对偶生成对抗网络来实现图像去雾python源码+项目说明+代码注释.zip (24个子文件)
项目说明.md 1KB
loss.png 80KB
predict.py 870B
net
Discriminator.py 2KB
Generator.py 5KB
test_data
1404_7.png 283KB
1423_5.png 296KB
1408_10.png 347KB
1414_10.png 229KB
1403_4.png 325KB
predict
1408_10.jpg 15KB
1423_5.jpg 15KB
1404_7.jpg 12KB
1414_10.jpg 11KB
1403_4.jpg 12KB
model
discriminator_b.pkl 10.58MB
discriminator_a.pkl 10.58MB
dual.py 20KB
train.py 8KB
util
loader.py 2KB
parseArgs.py 845B
showPlit.py 362B
logger.py 651B
pre_loader.py 2KB
共 24 条
- 1
z同学的编程之路
- 粉丝: 1861
- 资源: 2130
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- sony 索尼IMX334摄像头模组电路板AD版硬件PCB图(6层板).zip
- 基于flask和echarts融合交易策略的bitfinex可视化微服务.zip
- 包含了wvp-assist.tar wvp-talk.tar zlmediakit.tar .
- 3r4efgh53wgrf43tw
- 2024新版Java基础从入门到精通全套视频+资料下载
- Spring AI大模型视频教程+ChatGPT视频教程+OpenAI大模型视频教程(资料+视频教程)
- ABB工业机器人教程PDF版本
- 123321123323211
- yolov8实战第八天-pyqt5-yolov8实现车牌识别系统(论文(8700+字+数据集+完整部署代码+代码使用说明)
- 三相桥式全桥整流电路MATALB Simulink仿真文件
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
- 3
前往页