import os
import time
import numpy as np
import torch
import torchvision.utils as vutils
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, auc
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import torch.nn as nn
import model
import show
import record
class NNGraph(object):
def __init__(self, dataloader, config, isize):
super(NNGraph, self).__init__()
self.config = config
self.isize = isize
self.train_model = self._get_train_model(config)
record.record_dict(self.config, self.train_model["config"])
self.config = self.train_model["config"]
self.dataloader = dataloader
def _get_train_model(self, config):
train_model = model.init_train_model(config, self.isize)
# train_model = self._load_train_model(train_model)
return train_model
def _save_train_model(self):
model_dict = model.get_model_dict(self.train_model)
file_full_path = record.get_check_point_file_full_path(self.config)
torch.save(model_dict, file_full_path)
def _load_train_model(self, train_model):
'''
path: save/"dataset_image_size"_"batch_size"_
"number_of_generator_feature"_"number_of_discriminator_feature"_"size_of_z_latent"_"learn_rate"
/checkpoint.tar
'''
file_full_path = record.get_check_point_file_full_path(self.config)
if os.path.exists(file_full_path) and self.config["train_load_check_point_file"]:
checkpoint = torch.load(file_full_path)
train_model = model.load_model_dict(train_model, checkpoint)
return train_model
    def _train_step(self, data, i):
        """Run one adversarial optimization step on a single batch.

        data: a batch from the dataloader; data[0] holds the input images
            (data[1] is only used to resize the unused `label` buffer).
        i: batch index within the epoch; the discriminator update only runs
            every `generator_learntimes`-th step.

        Returns (errD, errG): discriminator and generator losses. errD is a
        dummy zero tensor on steps where the discriminator update is skipped,
        so callers can always call errD.item().
        """
        netG = self.train_model["netG"]
        optimizerG = self.train_model["optimizerG"]
        netD = self.train_model["netD"]
        optimizerD = self.train_model["optimizerD"]
        device = self.config["device"]
        # Teacher discriminator feeds extra (distillation-style) terms into
        # the discriminator loss below.
        netDTeacher = self.train_model["netTeacher"]
        optimizerDTeacher = self.train_model["optimizerTeacher"]
        # real_data = data[0].to(device)
        # Preallocated buffers; `input` is resized/filled from the batch below.
        input = torch.empty(size=(self.config["batch_size"], 3, self.isize, self.isize), dtype=torch.float32,
                            device=device)
        label = torch.empty(size=(self.config["batch_size"],), dtype=torch.float32, device=device)
        # gt = torch.empty(size=(self.config["batch_size"],), dtype=torch.long, device=device)
        real_label = torch.ones(size=(self.config["batch_size"],), dtype=torch.float32, device=device)
        fake_label = torch.zeros(size=(self.config["batch_size"],), dtype=torch.float32, device=device)
        # Copy the batch into the input buffer without building autograd history.
        with torch.no_grad():
            input.resize_(data[0].size()).copy_(data[0])
            # gt.resize_(data[1].size()).copy_(data[1])
            # NOTE(review): `label` is resized but never filled or read below —
            # looks like dead code; confirm before removing.
            label.resize_(data[1].size())
        # Generator forward: reconstruction plus both latent codes.
        fake, latent_i, latent_o = netG(input)
        # _, latent_o = netG(fake)
        # Discriminator forward on real and (detached) fake images.
        pred_real, feat_real, real_last, _ = netD(input)
        pred_fake, feat_fake, fake_last, _ = netD(fake.detach())
        # pred_real, feat_real = netD(input)
        # pred_fake, feat_fake = netD(fake.detach())
        pred_real_teacher, real_last_teacher = netDTeacher(input)
        pred_fake_teacher, fake_last_teacher = netDTeacher(fake.detach())
        # Placeholder loss for skipped-discriminator steps (see docstring).
        errD = torch.tensor([0])
        # Generator is updated every step (the helper also steps optimizerG —
        # presumably; confirm in model.get_Generator_loss).
        errG = model.get_Generator_loss(netG, netD, optimizerG, input, fake, latent_i, latent_o, self.config)
        if i % self.config["generator_learntimes"] == 0:
            for p in netD.parameters():  # reset requires_grad
                p.requires_grad = True  # they are set to False below in netG update
            # Clip discriminator weights into [-clamp_num, clamp_num].
            for parm in netD.parameters():
                parm.data.clamp_(-self.config["clamp_num"], self.config["clamp_num"])
            errD = model.get_Discriminator_loss(netD, optimizerD, pred_real, pred_fake, real_label, fake_label,
                                                real_last, fake_last, real_last_teacher, fake_last_teacher, optimizerDTeacher)
            #errD = model.get_Discriminator_loss(netD, optimizerD, pred_real, pred_fake, real_label, fake_label)
        # for p in netD.parameters():
        #     p.requires_grad = False
        #if errD.item() < 1e-5:
            #self.train_model["netG"].apply(model._weights_init)
        return errD, errG
        '''
        noise = model.get_noise(real_data, self.config)
        fake_data = netG(noise)
        label = model.get_label(real_data, self.config)
        label = label.to(torch.float32)
        errD, D_x, D_G_z1 = model.get_Discriminator_loss(netD, optimizerD, real_data, fake_data.detach(), label,
                                                         criterion, self.config)
        errG, D_G_z2 = model.get_Generator_loss(netG, netD, optimizerG, fake_data, label, criterion, self.config)
        return errD, errG, D_x, D_G_z1, D_G_z2
        '''
def _train_a_step(self, data, i, epoch):
start = time.time()
errD, errG = self._train_step(data, i)
end = time.time()
step_time = end - start
self.train_model["take_time"] = self.train_model["take_time"] + step_time
print_every = self.config["print_every"]
if i % print_every == 0:
record.print_status(step_time*print_every,
self.train_model["take_time"],
epoch,
i,
errD,
errG,
self.config,
self.dataloader)
return errD, errG
def _DCGAN_eval(self):
# fixed_noise: 64, nz, 1, 1
fixed_noise = self.train_model["fixed_noise"]
with torch.no_grad():
netG = self.train_model["netG"]
fake = netG(fixed_noise).detach().cpu() # 64, nc, 64, 64
return fake
def _save_generator_images(self, iters, epoch, i):
num_epochs = self.config["num_epochs"]
save_every = self.config["save_every"]
img_list = self.train_model["img_list"]
if (iters % save_every == 0) or ((epoch == num_epochs-1) and (i == len(self.dataloader)-1)):
fake = self._DCGAN_eval() # 64, nc, 64, 64
img_one = vutils.make_grid(fake, padding=2, normalize=True)
img_list.append(img_one)
show._show_one_img(img_one)
self._save_train_model()
    def _train_iters(self):
        """Main training loop: run all epochs, collect losses, then evaluate.

        Resumes from train_model["current_epoch"]/["current_iters"], appends
        (x, loss) pairs into G_losses/D_losses, and finishes with self.test()
        and a loss-curve plot.
        """
        num_epochs = self.config["num_epochs"]
        # Loss histories are [x-positions, values] pairs of parallel lists.
        G_losses = self.train_model["G_losses"]
        D_losses = self.train_model["D_losses"]
        iters = self.train_model["current_iters"]
        start_epoch = self.train_model["current_epoch"]
        if self.config["add_gasuss"]:
            # NOTE(review): this pre-pass mutates only the local `data` batch;
            # unless the dataloader yields persistent/shared batches, the added
            # noise is discarded before the training loop below — verify intent.
            for _, data in enumerate(self.dataloader['train'], 0):
                data[0] = self.gasuss_noise(data[0])
        for epoch in range(start_epoch, num_epochs):
            self.train_model["current_epoch"] = epoch
            for i, data in enumerate(self.dataloader['train'], 0):
                errD, errG = self._train_a_step(data, i, epoch)
                # x-axis position = global batch index across epochs.
                G_losses[0].append(i + epoch * len(self.dataloader['train']))
                G_losses[1].append(errG.item())
                # errD == 0 marks skipped discriminator steps; don't plot them.
                if errD.item() != 0:
                    D_losses[0].append(i + epoch * len(self.dataloader['train']))
                    D_losses[1].append(errD.item())
                iters += 1
                self.train_model["current_iters"] = iters
                # self._save_generator_images(iters, epoch, i)
        # NOTE(review): self.test() and self.gasuss_noise() are not defined in
        # the visible portion of this file — presumably defined elsewhere.
        self.test()
        self._save_loss_images(G_losses, D_losses)
def _save_loss_images(self, G_losses, D_losses):
x1 = G_losses[0]
x2 = D_losses[0]
y1 = G_losses[1]
y2 = D_losses[1]
fig = plt.figure
# NOTE(review): the remainder of this file was lost during extraction and
# replaced by download-site boilerplate (non-code residue). The body of
# _save_loss_images above is truncated mid-statement (`fig = plt.figure` is
# likely missing its call parentheses and the plotting/saving code that
# followed). Restore this method from the original source before use.