import torch
import glob
import scipy
import numpy as np
import os
from scipy.io import loadmat
from ResNet import ResidualNet
from torch import optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from data_generator import DenoisingDataset
batch_size = 16
clean_dir = 'D:\\DAS数据集\\duoleixing\\chunjing\\'
noise_dir = 'D:\\DAS数据集\\duoleixing\\suijinoise\\'
save_dir = 'D:\\DAS数据集\\duoleixing\\suiji_model\\'
# 创建网络实例
model = ResidualNet()
model.train()
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
#####数据加载函数
file_list = glob.glob(clean_dir + '/*.mat') # get name list of all .png files
data1 = []
for i in range(len(file_list)):
mat_content = loadmat(file_list[i])
patches = mat_content['chunjing']
data1.append(patches)
data1 = np.array(data1, dtype='float32')
data1 = np.expand_dims(data1, axis=3) #给每个数据加一个中括号
discard_n = len(data1) - len(data1) // batch_size * batch_size # because of batch namalization
data1 = np.delete(data1, range(discard_n), axis=0)
print('^_^-clean data finished-^_^')
file_list2 = glob.glob(noise_dir + '/*.mat') # get name list of all .png files
data2 = []
for i in range(len(file_list2)):
mat_content2 = loadmat(file_list2[i])
patches2 = mat_content2['noise']
data2.append(patches2)
data2 = np.array(data2, dtype='float32')
data2 = np.expand_dims(data2, axis=3) #给每个数据加一个中括号
discard_n = len(data2) - len(data2) // batch_size * batch_size # because of batch namalization
data2 = np.delete(data2, range(discard_n), axis=0)
print('^_^-noise data finished-^_^')
#####数据加载结束
xs = torch.from_numpy(data1.transpose((0, 3, 1, 2)))
sigma = torch.from_numpy(data2.transpose((0, 3, 1, 2)))
DDataset = DenoisingDataset(xs, sigma)
DLoader = DataLoader(dataset=DDataset, num_workers=0, drop_last=True, batch_size=batch_size, shuffle=True)
epoch_loss = 0
num_epochs = 100
for epoch in range(num_epochs):
for n_count, batch_yx in enumerate(DLoader):
optimizer.zero_grad()
targets, inputs = batch_yx[1], batch_yx[0]
outputs = model(inputs) ##得到输入输出
loss = criterion(outputs, targets)
epoch_loss += loss.item()
loss.backward()
optimizer.step()
print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}')
torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch + 1)))
![avatar](https://profile-avatar.csdnimg.cn/default.jpg!1)
highheart
- 粉丝: 25
- 资源: 8
最新资源
- #_ssm_127_mysql_私人书店管理系统_.zip
- #_ssm_128_mysql_网络安全与信息管理学院班级管理系统_.zip
- #_ssm_132_mysql_校园生活管理系统_.zip
- #_ssm_133_mysql_校园招聘信息管理系统_.zip
- #_ssm_135_mysql_新疆旅游管理系统_.zip
- #_ssm_139_mysql_一站式乡村服务系统wlw_.zip
- #_ssm_137_mysql_数据结构课堂学生考勤管理系统_.zip
- #_ssm_145_mysql_中学教务管理系统_.zip
- #_ssm_146_mysql_作业提交与批改程序_.zip
- #_ssm_147_mysql_毕业生离校管理系统_.zip
- #_ssm_151_mysql_在线汽车交易系统_.zip
- C++学习项目资料分享
- 利用ai漫改渐变国庆头像项目玩法教程,可一键生成风口赛道
- #_ssm_154_mysql_中小型超市管理系统_.zip
- 混剪德云语录项目玩法教程,带你揭秘流量密码
- Redis-Windows-8.0
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
![feedback](https://img-home.csdnimg.cn/images/20220527035711.png)
![feedback](https://img-home.csdnimg.cn/images/20220527035711.png)
![feedback-tip](https://img-home.csdnimg.cn/images/20220527035111.png)