-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathCLI_interface.py
More file actions
80 lines (57 loc) · 2.22 KB
/
Copy pathCLI_interface.py
File metadata and controls
80 lines (57 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from torchvision import transforms, models
from Helper import load_image, get_images, im_convert
from Style_Transfer import style_transfer
import os
import argparse
def is_valid_directory(parser, arg):
    """Validate that *arg* names an existing regular file.

    Intended as an argparse ``type=`` callback. Despite the historical name,
    this checks ``os.path.isfile``: the CLI arguments are the content and
    style image paths, which must be files, not directories.

    Args:
        parser: The ``argparse.ArgumentParser`` used to report the error.
        arg: The raw command-line string to validate.

    Returns:
        *arg* unchanged when it names an existing file.

    Raises:
        SystemExit: raised by ``parser.error`` when *arg* is not an existing
            file (argparse prints the usage and message to stderr first).
    """
    if not os.path.isfile(arg):
        # parser.error never returns: it prints usage + message and exits(2).
        parser.error('The file {} does not exist!'.format(arg))
    return arg
def get_paths(argv=None):
    """Parse the command line and return the content and style image paths.

    Both positional arguments are validated with ``is_valid_directory``,
    which requires each to name an existing file; on failure argparse
    prints an error and exits.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, matching the
            previous behaviour.

    Returns:
        A ``(content_image_path, style_image_path)`` tuple of strings.
    """
    parser = argparse.ArgumentParser(
        description='Style Transfer. You need to pass two image file paths: '
                    'the content image and the style image.')

    def _existing_file(value):
        # Bind the parser so validation errors are reported through it.
        return is_valid_directory(parser, value)

    parser.add_argument(
        'content_image_dir',
        metavar='content_image',
        help='Path to the content image file.',
        type=_existing_file)
    parser.add_argument(
        'style_image_dir',
        metavar='style_image',
        help='Path to the style image file.',
        type=_existing_file)
    args = parser.parse_args(argv)
    return args.content_image_dir, args.style_image_dir
def main():
    """Run neural style transfer on the two images named on the command line.

    Saves the stylised result to ``result.png`` and then shows the content
    image and the result side by side.
    """
    # Parse and validate the CLI arguments first so a bad path is reported
    # immediately, before the (possibly downloaded) VGG19 weights are loaded.
    content_image_path, style_image_path = get_paths()

    # Only the convolutional ``features`` part of VGG19 is used.
    # NOTE(review): ``pretrained=True`` is deprecated since torchvision 0.13 in
    # favour of ``weights=models.VGG19_Weights.DEFAULT``; kept as-is so older
    # torchvision versions keep working.
    vgg = models.vgg19(pretrained=True).features

    # Freeze the network weights: VGG is used as a fixed feature extractor.
    for param in vgg.parameters():
        param.requires_grad_(False)

    # Move the model to the GPU, if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vgg.to(device)

    content, style = get_images(content_image_path, style_image_path)

    # Weights for each style layer; weighting earlier layers more results in
    # *larger* style artifacts. `conv4_2` is not listed here — per the original
    # author it serves as the content representation (presumably inside
    # style_transfer; confirm there).
    style_weights = {'conv1_1': 1.,
                     'conv2_1': 0.8,
                     'conv3_1': 0.5,
                     'conv4_1': 0.3,
                     'conv5_1': 0.1}
    content_weight = 1  # alpha
    style_weight = 1e6  # beta
    steps = 3000

    target = style_transfer(content, style, vgg, steps,
                            content_weight, style_weight, style_weights)
    plt.imsave("result.png", im_convert(target))

    # Display the content image next to the final, target image.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    ax1.imshow(im_convert(content))
    ax1.set_title("Content")
    ax2.imshow(im_convert(target))
    ax2.set_title("Result")
    plt.show()


if __name__ == "__main__":
    main()