import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel
import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)
# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-normalize_gradients", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)
# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')
# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)
parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')
parser.add_argument("-multidevice_strategy", default='4,7,29')
params = parser.parse_args()
Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images
def main():
dtype, multidevice, backward_device = setup_gpu()
cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)
content_image = preprocess(params.content_image, params.image_size).type(dtype)
style_image_input = params.style_image.split(',')
style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
for image in style_image_input:
if os.path.isdir(image):
images = (image + "/" + file for file in os.listdir(image)
if os.path.splitext(file)[1].lower() in ext)
style_image_list.extend(images)
else:
style_image_list.append(image)
style_images_caffe = []
for image in style_image_list:
style_size = int(params.image_size * params.style_scale)
img_caffe = preprocess(image, style_size).type(dtype)
style_images_caffe.append(img_caffe)
if params.init_image != None:
image_size = (content_image.size(2), content_image.size(3))
init_image = preprocess(params.init_image, image_size).type(dtype)
# Handle style blending weights for multiple style inputs
style_blend_weights = []
if params.style_blend_weights == None:
# Style blending not specified, so use equal weighting
for i in style_image_list:
style_blend_weights.append(1.0)
for i, blend_weights in enumerate(style_blend_weights):
style_blend_weights[i] = int(style_blend_weights[i])
else:
style_blend_weights = params.style_blend_weights.split(',')
assert len(style_blend_weights) == len(style_image_list), \
"-style_blend_weights and -style_images must have the same number of elements!"
# Normalize the style blending weights so they sum to 1
style_blend_sum = 0
for i, blend_weights in enumerate(style_blend_weights):
style_blend_weights[i] = float(style_blend_weights[i])
style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
for i, blend_weights in enumerate(style_blend_weights):
style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)
content_layers = params.content_layers.split(',')
style_layers = params.style_layers.split(',')
# Set up the network, inserting style and content loss modules
cnn = copy.deepcopy(cnn)
content_losses, style_losses, tv_losses = [], [], []
next_content_idx, next_style_idx = 1, 1
net = nn.Sequential()
c, r = 0, 0
if params.tv_weight > 0:
tv_mod = TVLoss(params.tv_weight).type(dtype)
net.add_module(str(len(net)), tv_mod)
tv_losses.append(tv_mod)
for i, layer in enumerate(list(cnn), 1):
if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
if isinstance(layer, nn.Conv2d):
net.add_module(str(len(net)), layer)
if layerList['C'][c] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)
if layerList['C'][c] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
c+=1
if isinstance(layer, nn.ReLU):
net.add_module(str(len(net)), layer)
if layerList['R'][r] in content_layers:
print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = ContentLoss(params.content_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
content_losses.append(loss_module)
next_content_idx += 1
if layerList['R'][r] in style_layers:
print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
loss_module = StyleLoss(params.style_weight, params.normalize_gradients)
net.add_module(str(len(net)), loss_module)
style_losses.append(loss_module)
next_style_idx += 1
r+=1
if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
net.add_module(str(len(net)), layer)
if multidevice:
net = setup_multi_device(net)
# Capture content targets
for i in content_losses:
i.mode = 'capture'
print("Capturing content targets")
print_torch(net, multidevice)
net(content_image)
# Capture style targets
for i in content_losses:
i.mode = 'None'
for i, image in enumerate(style_images_caffe):
print("Capturing style target " + str(i+1))
for j in style_losses:
j.mode = 'capture'
j.blend_weight = style_blend_weights[i]
net(style_images_caffe[i])
# Set all loss modules to loss mode
for i in content_loss

onnx
- 粉丝: 1w+
- 资源: 5637
最新资源
- 嵌入式芯片与系统设计大赛-基于深度学习的机械故障检测示波器的设计.zip
- 【毕业设计-python】python基于深度学习的音乐推荐方法研究系统(django)(完整前后端+mysql+说明文档+LW).zip
- 【毕业设计-python】python旅游景点方面级别情感分析语料库与模型(完整前后端+mysql+说明文档+LW).zip
- 【毕业设计-python】python某在线中药店销售数据统计与分析系统(完整前后端+mysql+说明文档+LW).zip
- 【毕业设计-python】python某大学学生影响力分析系统(完整前后端+mysql+说明文档+LW).zip
- 基于双二阶广义积分器的软件锁相环仿真模型研究:在不对称工况下的应用与性能分析,基于双二阶广义积分器的软件锁相环仿真模型研究:不对称工况下的性能对比及其在并网逆变器中的应用,基于双二阶广义积分器的软件锁
- 网上花店系统(SSH).zip(毕设&课设&实训&大作业&竞赛&项目)
- 基于Pytorch的口罩佩戴检测.zip(毕设&课设&实训&大作业&竞赛&项目)
- 基于PyCharm构建高加热水平衡计算的小型应用程序
- 基于LCL滤波器的单相光伏逆变器控制策略设计与MATLAB-Simulink仿真验证,基于LCL滤波器的单相光伏逆变器控制设计的MATLAB-Simulink仿真分析与实现,基于LCL滤波器的单相光伏
- 基于强化学习的五子棋人工智能系统.zip(毕设&课设&实训&大作业&竞赛&项目)
- CAN分析仪资料,含软件,驱动,创芯科技
- 基于java_ssm_mysql实现的房屋_公寓出租网平台_租赁平台.zip(毕设&课设&实训&大作业&竞赛&项目)
- 基于51单片机的交通灯设计,Proteus仿真与C语言编程实现,详解代码及电路原理,可加工实物 ,基于51单片机的交通灯设计项目:Proteus仿真与C语言编程实现,项目:交通灯-基于51单片机的交
- 基于Java和Mysql的网上购物系统,主要业务是完成在线购物的功能。系统包括前台商家系统、买家系统及后台管理员操作系统。.zip(毕设&课设&实训&大作业&竞赛&项目)
- 基于javaWeb 的网上商城系统.zip(毕设&课设&实训&大作业&竞赛&项目)
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈


