import time
import itertools
from dataset import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
from .ms.networks import *
from utils import *
from glob import glob
from .face_features import FaceFeatures
from mindspore import dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore import Tensor, export, load_checkpoint, load_param_into_net, save_checkpoint
import mindspore.nn as nn
import mindspore.ops as ops
from .ms.grad import value_and_grad
def RhoClipper(clip_max, module):
    """Clamp ``module.rho`` into ``[0, clip_max]`` when that attribute exists.

    Args:
        clip_max: upper bound handed to ``clamp``; the lower bound is fixed at 0.
        module: any object. One without a ``rho`` attribute is returned untouched.
            Assumes ``module.rho.data`` exposes a ``clamp(min, max)`` method.

    Returns:
        The same ``module`` object; ``module.rho.data`` is replaced in place.
    """
    if not hasattr(module, 'rho'):
        return module
    # Swap the parameter's data for its clamped copy.
    module.rho.data = module.rho.data.clamp(0, clip_max)
    return module
def WClipper(clip_max, module, clip_min=0):
    """Clamp ``module.w_gamma`` and ``module.w_beta`` into ``[clip_min, clip_max]``.

    Each attribute is handled independently: whichever of the two exists is
    clamped, and a missing one is skipped.

    Args:
        clip_max: upper bound handed to ``clamp``.
        module: any object. Assumes each present attribute exposes a ``.data``
            value with a ``clamp(min, max)`` method.
        clip_min: lower bound handed to ``clamp`` (default 0).

    Returns:
        The same ``module`` object, updated in place.
    """
    # The bounds come from the arguments. This is a free function, so there is
    # no ``self`` to read ``clip_min``/``clip_max`` from.
    for name in ('w_gamma', 'w_beta'):
        if hasattr(module, name):
            param = getattr(module, name)
            param.data = param.data.clamp(clip_min, clip_max)
    return module
class UgatitSadalinHourglass(object):
def __init__(self, args):
    """Copy the training configuration from *args* and print a summary of it.

    Plain hyper-parameters are copied onto the instance unchanged. The five loss
    weights are wrapped in MindSpore ``Tensor`` objects. If both cuDNN and
    ``benchmark_flag`` are enabled, the cuDNN benchmark mode is switched on.
    Finally, the configuration is printed to stdout.

    Args:
        args: parsed argument namespace; must provide every attribute read below.
    """
    self.light = args.light
    self.model_name = 'UGATIT_light' if self.light else 'UGATIT'

    # Values copied verbatim from the argument namespace.
    for field in ('result_dir', 'dataset', 'iteration', 'decay_flag',
                  'batch_size', 'print_freq', 'save_freq', 'lr', 'ch'):
        setattr(self, field, getattr(args, field))

    # Loss weights, held as MindSpore tensors.
    for field in ('adv_weight', 'cycle_weight', 'identity_weight',
                  'cam_weight', 'faceid_weight'):
        setattr(self, field, Tensor(getattr(args, field)))

    # Discriminator / image settings.
    for field in ('n_dis', 'img_size', 'img_ch'):
        setattr(self, field, getattr(args, field))

    # Device string derived from the first configured GPU id.
    self.device = f'cuda:{args.gpu_ids[0]}'
    for field in ('gpu_ids', 'benchmark_flag', 'resume', 'rho_clipper',
                  'w_clipper', 'pretrained_model'):
        setattr(self, field, getattr(args, field))

    # `torch` is not imported explicitly in this file's visible imports;
    # presumably it comes from a star import — verify.
    if torch.backends.cudnn.enabled and self.benchmark_flag:
        print('set benchmark !')
        torch.backends.cudnn.benchmark = True

    # Each row is unpacked into a single print() call; the empty row emits a blank line.
    summary = (
        ("##### Information #####",),
        ("# light : ", self.light),
        ("# dataset : ", self.dataset),
        ("# batch_size : ", self.batch_size),
        ("# iteration per epoch : ", self.iteration),
        ("##### Discriminator #####",),
        ("# discriminator layer : ", self.n_dis),
        (),
        ("##### Weight #####",),
        ("# adv_weight : ", self.adv_weight),
        ("# cycle_weight : ", self.cycle_weight),
        ("# faceid_weight : ", self.faceid_weight),
        ("# identity_weight : ", self.identity_weight),
        ("# cam_weight : ", self.cam_weight),
        ("# rho_clipper: ", self.rho_clipper),
        ("# w_clipper: ", self.w_clipper),
    )
    for row in summary:
        print(*row)
##################################################################################
# Model
##################################################################################
def build_model(self):
""" DataLoader """
# train_transform = transforms.Compose([
# transforms.RandomHorizontalFlip(),
# transforms.Resize((self.img_size + 30, self.img_size+30)),
# transforms.RandomCrop(self.img_size),
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ])
# test_transform = transforms.Compose([
# transforms.Resize((self.img_size, self.img_size)),
# transforms.ToTensor(),
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
# ])
train_transform = [
C.Decode(),
C.Resize((self.img_size + 30, self.img_size+30)),
C.RandomCrop(self.img_size),
C.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
C.HWC2CHW()
]
test_transform = [
C.Decode(),
C.Resize((self.img_size, self.img_size)),
C.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
C.HWC2CHW()
]
# self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
# self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
# self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
# self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
self.trainA = ds.ImageFolderDataset(os.path.join('dataset', self.dataset, 'trainAA'), num_parallel_workers=8, shuffle=True)
self.trainB = ds.ImageFolderDataset(os.path.join('dataset', self.dataset, 'trainBB'), num_parallel_workers=8, shuffle=True)
self.testA = ds.ImageFolderDataset(os.path.join('dataset', self.dataset, 'testAA'), num_parallel_workers=8, shuffle=False)
self.testB = ds.ImageFolderDataset(os.path.join('dataset', self.dataset, 'testBB'), num_parallel_workers=8, shuffle=False)
self.trainA = self.trainA.map(operations=train_transform, input_columns="image", num_parallel_workers=8)
self.trainA_loader = self.trainA.batch(batch_size=1, drop_remainder=True)
self.trainB = self.trainB.map(operations=train_transform, input_columns="image", num_parallel_workers=8)
self.trainB_loader = self.trainB.batch(batch_size=1, drop_remainder=True)
self.testA = self.testA.map(operations=test_transform, input_columns="image", num_parallel_workers=8)
self.testA_loader = self.testA.batch(batch_size=1, drop_remainder=False)
self.testB = self.testB.map(operations=test_transform, input_columns="image", num_parallel_workers=8)
self.testB_loader = self.testB.batch(batch_size=1, drop_remainder=False)
# self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
# self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
# self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
# self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)
""" Define Generator, Discriminator """
self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light)
self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light)
self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7)
self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5)
self.facenet = FaceFeatures('pretrained_models/model_mobilefacenet.pth', self.device)
""" Trainer """
# self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()), lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001)
# self.D_optim = torch.optim.Adam(
# itertools.chain(self.disGA.parameters(), self.disGB.parameters(), self.disLA.parameters(), self.disLB.parameters()),
# lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001
# )
# g_params = [{'params': self.genA2B.trainable_p