# Copyright (c) 2015-2019 Anish Athalye. Released under GPLv3.
import os
import math
import re
from argparse import ArgumentParser
from collections import OrderedDict
from PIL import Image
import numpy as np
import scipy.misc
from stylize import stylize
# Default values for the command-line options (used by build_parser() and main()).
# Loss-term weights:
CONTENT_WEIGHT = 5e0
CONTENT_WEIGHT_BLEND = 1  # content loss = conv4_2 * blend + conv5_2 * (1 - blend)
STYLE_WEIGHT = 5e2
TV_WEIGHT = 1e2  # total variation regularization weight
# weight(layer<n+1>) = STYLE_LAYER_WEIGHT_EXP * weight(layer<n>); 1 keeps all
# style layers equally weighted.
STYLE_LAYER_WEIGHT_EXP = 1
# Optimizer settings; BETA1, BETA2 and EPSILON parameterize Adam.
LEARNING_RATE = 1e1
BETA1 = 0.9
BETA2 = 0.999
EPSILON = 1e-08
# Scale applied to every style image when --style-scales is not given
# (not exposed as a parser default; read directly by main()).
STYLE_SCALE = 1.0
ITERATIONS = 1000
# Path to the network parameters file; main() errors out if it does not exist.
VGG_PATH = 'imagenet-vgg-verydeep-19.mat'
POOLING = 'max'  # pooling layer configuration: 'max' or 'avg'
def build_parser():
    """Build the command-line argument parser for this script.

    Most defaults come from the module-level constants (CONTENT_WEIGHT,
    STYLE_WEIGHT, ITERATIONS, VGG_PATH, ...). Options declared without a
    default parse to None when omitted; the store_true flags parse to False.

    Returns:
        argparse.ArgumentParser: the configured parser; call ``parse_args()``
        on it to obtain the options namespace.
    """
    parser = ArgumentParser()
    parser.add_argument('--content',
            dest='content', help='content image',
            metavar='CONTENT', required=True)
    parser.add_argument('--styles',
            dest='styles',
            nargs='+', help='one or more style images',
            metavar='STYLE', required=True)
    parser.add_argument('--output',
            dest='output', help='output path',
            metavar='OUTPUT', required=True)
    parser.add_argument('--iterations', type=int,
            dest='iterations', help='iterations (default %(default)s)',
            metavar='ITERATIONS', default=ITERATIONS)
    parser.add_argument('--print-iterations', type=int,
            dest='print_iterations', help='statistics printing frequency',
            metavar='PRINT_ITERATIONS')
    # '%%' is argparse's escape for a literal '%' in help text.
    parser.add_argument('--checkpoint-output',
            dest='checkpoint_output',
            help='checkpoint output format, e.g. output_{:05}.jpg or '
                 'output_%%05d.jpg',
            metavar='OUTPUT', default=None)
    parser.add_argument('--checkpoint-iterations', type=int,
            dest='checkpoint_iterations', help='checkpoint frequency',
            metavar='CHECKPOINT_ITERATIONS', default=None)
    parser.add_argument('--progress-write', default=False, action='store_true',
            help="write iteration progress data to OUTPUT's dir",
            required=False)
    parser.add_argument('--progress-plot', default=False, action='store_true',
            help="plot iteration progress data to OUTPUT's dir",
            required=False)
    parser.add_argument('--width', type=int,
            dest='width', help='output width',
            metavar='WIDTH')
    parser.add_argument('--style-scales', type=float,
            dest='style_scales',
            nargs='+', help='one or more style scales',
            metavar='STYLE_SCALE')
    parser.add_argument('--network',
            dest='network', help='path to network parameters (default %(default)s)',
            metavar='VGG_PATH', default=VGG_PATH)
    parser.add_argument('--content-weight-blend', type=float,
            dest='content_weight_blend',
            help='content weight blend, conv4_2 * blend + conv5_2 * (1-blend) '
                 '(default %(default)s)',
            metavar='CONTENT_WEIGHT_BLEND', default=CONTENT_WEIGHT_BLEND)
    parser.add_argument('--content-weight', type=float,
            dest='content_weight', help='content weight (default %(default)s)',
            metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)
    parser.add_argument('--style-weight', type=float,
            dest='style_weight', help='style weight (default %(default)s)',
            metavar='STYLE_WEIGHT', default=STYLE_WEIGHT)
    parser.add_argument('--style-layer-weight-exp', type=float,
            dest='style_layer_weight_exp',
            help='style layer weight exponential increase - '
                 'weight(layer<n+1>) = weight_exp*weight(layer<n>) '
                 '(default %(default)s)',
            metavar='STYLE_LAYER_WEIGHT_EXP', default=STYLE_LAYER_WEIGHT_EXP)
    parser.add_argument('--style-blend-weights', type=float,
            dest='style_blend_weights', help='style blending weights',
            nargs='+', metavar='STYLE_BLEND_WEIGHT')
    parser.add_argument('--tv-weight', type=float,
            dest='tv_weight',
            help='total variation regularization weight (default %(default)s)',
            metavar='TV_WEIGHT', default=TV_WEIGHT)
    parser.add_argument('--learning-rate', type=float,
            dest='learning_rate', help='learning rate (default %(default)s)',
            metavar='LEARNING_RATE', default=LEARNING_RATE)
    parser.add_argument('--beta1', type=float,
            dest='beta1', help='Adam: beta1 parameter (default %(default)s)',
            metavar='BETA1', default=BETA1)
    parser.add_argument('--beta2', type=float,
            dest='beta2', help='Adam: beta2 parameter (default %(default)s)',
            metavar='BETA2', default=BETA2)
    parser.add_argument('--eps', type=float,
            dest='epsilon', help='Adam: epsilon parameter (default %(default)s)',
            metavar='EPSILON', default=EPSILON)
    parser.add_argument('--initial',
            dest='initial', help='initial image',
            metavar='INITIAL')
    parser.add_argument('--initial-noiseblend', type=float,
            dest='initial_noiseblend',
            help='ratio of blending initial image with normalized noise '
                 '(if no initial image specified, content image is used) '
                 '(default %(default)s)',
            metavar='INITIAL_NOISEBLEND')
    parser.add_argument('--preserve-colors', action='store_true',
            dest='preserve_colors',
            help='style-only transfer (preserving colors) - if color transfer '
                 'is not needed')
    parser.add_argument('--pooling',
            dest='pooling',
            help='pooling layer configuration: max or avg (default %(default)s)',
            metavar='POOLING', default=POOLING)
    parser.add_argument('--overwrite', action='store_true', dest='overwrite',
            help='write file even if there is already a file with that name')
    return parser
def fmt_imsave(fmt, iteration):
    """Expand a checkpoint filename template for the given iteration.

    Templates containing a ``{...}`` placeholder are expanded with
    ``str.format``; otherwise templates containing ``%`` are expanded with
    printf-style ``%`` formatting.

    Args:
        fmt: filename template, e.g. ``output_{:05}.jpg`` or ``output_%05d.jpg``.
        iteration: value substituted into the template.

    Returns:
        The expanded filename string.

    Raises:
        ValueError: if ``fmt`` has neither a brace placeholder nor a ``%``.
    """
    has_brace_field = re.match(r'^.*\{.*\}.*$', fmt) is not None
    if has_brace_field:
        return fmt.format(iteration)
    if '%' not in fmt:
        raise ValueError("illegal format string '{}'".format(fmt))
    return fmt % iteration
def main():
# https://stackoverflow.com/a/42121886
key = 'TF_CPP_MIN_LOG_LEVEL'
if key not in os.environ:
os.environ[key] = '2'
parser = build_parser()
options = parser.parse_args()
if not os.path.isfile(options.network):
parser.error("Network %s does not exist. (Did you forget to "
"download it?)" % options.network)
if [options.checkpoint_iterations,
options.checkpoint_output].count(None) == 1:
parser.error("use either both of checkpoint_output and "
"checkpoint_iterations or neither")
if options.checkpoint_output is not None:
if re.match(r'^.*(\{.*\}|%.*).*$', options.checkpoint_output) is None:
parser.error("To save intermediate images, the checkpoint_output "
"parameter must contain placeholders (e.g. "
"`foo_{}.jpg` or `foo_%d.jpg`")
content_image = imread(options.content)
style_images = [imread(style) for style in options.styles]
width = options.width
if width is not None:
new_shape = (int(math.floor(float(content_image.shape[0]) /
content_image.shape[1] * width)), width)
content_image = scipy.misc.imresize(content_image, new_shape)
target_shape = content_image.shape
for i in range(len(style_images)):
style_scale = STYLE_SCALE
if options.style_scales is not None:
style_scale = options.style_scales[i]
style_images[i] = scipy.misc.imresize(style_images[i], style_scale *
target_shape[1] / style_images[i].shape[1])
style_blend_weights = options.style_blend_weights
if style_blend_weights is None:
# default is equal weights
style_blend_weights = [1.0/len(style_images) for _ in style_images]
else:
total_blend_weight = sum(style_blend_weights)
style_blend_weights = [weight/total_blend_weight