Commit b303f41 ("first"), 1 parent 0927651

35 files changed: 4,069 additions & 0 deletions

.gitignore

Lines changed: 3 additions & 0 deletions

```
*__pycache__*
pretrain_models
*checkpoint*
```

ReadMe.md

Lines changed: 116 additions & 0 deletions
# DCT-NET.Pytorch
Unofficial implementation of DCT-Net: Domain-Calibrated Translation for Portrait Stylization.<br>
You can find the official version [here](https://github.com/menyifang/DCT-Net).
![](assets/net.png)

## show
![img](assets/ldh.png)
![video](assets/xcaq.gif)

## environment
You can build your environment by following [this](https://github.com/rosinality/stylegan2-pytorch).<br>
```pip install tensorboardX``` for training visualization.
## how to run
### train
Download the pretrained weights.
#### CCN
1. Prepare the style pictures and align them.<br>
The image layout looks like this:<br>
style-photos/<br>
|-- 000000.png<br>
|-- 000006.png<br>
|-- 000010.png<br>
|-- 000011.png<br>
|-- 000015.png<br>
|-- 000028.png<br>
|-- 000039.png<br>
2. Set your own paths in [ccn_config](./model/styleganModule/config.py#L7).
3. Train the CCN:<br>

```shell
# single gpu
python train.py \
    --model ccn \
    --batch_size 16 \
    --checkpoint_path checkpoint \
    --lr 0.002 \
    --print_interval 100 \
    --save_interval 100
```

```shell
# multi gpu
python -m torch.distributed.launch train.py \
    --model ccn \
    --batch_size 16 \
    --checkpoint_path checkpoint \
    --lr 0.002 \
    --print_interval 100 \
    --save_interval 100 \
    --dist
```
After roughly 1,000 steps you can stop training.
#### TTN
1. Prepare expression information.<br>
You can follow [LVT](https://github.com/LeslieZhoa/LVT) to estimate facial landmarks.<br>
```shell
cd utils
# --img_base: your real-image base path, like ffhq
# --pool_num: number of worker processes
# --LVT: the path where you put LVT
# --train: pass for training data, omit for val data
python get_face_expression.py \
    --img_base '' \
    --pool_num 2 \
    --LVT '' \
    --train
```
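get_face_expression.py produces per-face expression scores that TTNLoader later normalizes (left eye, right eye, lip). The exact scoring lives in LVT; one common landmark-based measure of eye openness is the eye aspect ratio, sketched here purely as an illustration (the function and landmark ordering are assumptions, not this repo's code):

```python
import numpy as np

def eye_aspect_ratio(pts):
    """Openness score from 6 eye landmarks ordered corner, two upper-lid
    points, corner, two lower-lid points (as in the common 68-point layout).
    Hypothetical helper, not part of this repo."""
    pts = np.asarray(pts, dtype=float)
    v1 = np.linalg.norm(pts[1] - pts[5])  # inner upper lid to inner lower lid
    v2 = np.linalg.norm(pts[2] - pts[4])  # outer upper lid to outer lower lid
    h = np.linalg.norm(pts[0] - pts[3])   # horizontal corner-to-corner span
    return (v1 + v2) / (2.0 * h)

open_eye = [(0, 0), (1, 2), (2, 2), (3, 0), (2, -2), (1, -2)]
closed_eye = [(0, 0), (1, 0.2), (2, 0.2), (3, 0), (2, -0.2), (1, -0.2)]
print(eye_aspect_ratio(open_eye) > eye_aspect_ratio(closed_eye))  # True: an open eye scores higher
```

Whatever the actual score, TTNLoader only assumes it is a 3-vector whose components can be min-max normalized.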
2. Prepare your generator images.<br>
```shell
cd utils
# --model_path: trained CCN model path
# --output_path: where to save the generated images
python get_tcc_input.py \
    --model_path '' \
    --output_path ''
```
__Manually select roughly 5k–10k good images.__
3. Set your own paths in [ttn_config](./model/Pix2PixModule/config.py#21):
```python
# for example
self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
```
4. Train the TTN:
```shell
# single- and multi-gpu invocations work as for the CCN
python train.py \
    --model ttn \
    --batch_size 64 \
    --checkpoint_path checkpoint \
    --lr 2e-4 \
    --print_interval 100 \
    --save_interval 100 \
    --dist
```
## inference
Follow inference.py to set your own TTN model path and image path:<br>
```python inference.py```
## Credits
SEAN model and implementation:<br>
https://github.com/ZPdesu/SEAN Copyright © 2020, ZPdesu.<br>
License https://github.com/ZPdesu/SEAN/blob/master/LICENSE.md

stylegan2-pytorch model and implementation:<br>
https://github.com/rosinality/stylegan2-pytorch Copyright © 2019, rosinality.<br>
License https://github.com/rosinality/stylegan2-pytorch/blob/master/LICENSE

White-box-Cartoonization model and implementation:<br>
https://github.com/SystemErrorWang/White-box-Cartoonization Copyright © 2020, SystemErrorWang.<br>

White-box-Cartoonization PyTorch model and implementation:<br>
https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch Copyright © 2022, vinesmsuic.<br>
License https://github.com/vinesmsuic/White-box-Cartoonization-PyTorch/blob/main/LICENSE

arcface-pytorch model and implementation:<br>
https://github.com/ronghuaiyang/arcface-pytorch Copyright © 2018, ronghuaiyang.<br>

assets/ldh.png (445 KB)

assets/net.png (101 KB)

assets/xcaq.gif (34.9 MB)

data/CCNLoader.py

Lines changed: 45 additions & 0 deletions

```python
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''

import os
import random

import PIL.Image as Image
from torchvision import transforms

from data.DataLoader import DatasetBase


class CCNData(DatasetBase):
    def __init__(self, slice_id=0, slice_count=1, dist=False, **kwargs):
        super().__init__(slice_id, slice_count, dist, **kwargs)

        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        root = kwargs['root']
        self.paths = [os.path.join(root, f) for f in os.listdir(root)]
        self.length = len(self.paths)
        random.shuffle(self.paths)

    def __getitem__(self, i):
        # wrap around so the sampler can run past the real dataset size
        idx = i % self.length
        img_path = self.paths[idx]

        with Image.open(img_path) as img:
            return self.transform(img)

    def __len__(self):
        # report a large epoch length; indices wrap in __getitem__
        return max(100000, self.length)
```
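ToTensor scales uint8 pixels to [0, 1] and Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) then maps them to [-1, 1]; the arithmetic, isolated in plain numpy:

```python
import numpy as np

x = np.array([0, 128, 255], dtype=np.uint8)
t = x / 255.0        # what transforms.ToTensor does to uint8 input
n = (t - 0.5) / 0.5  # what Normalize(mean=0.5, std=0.5) does per channel
print(n[0], n[2])    # -1.0 1.0
```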

data/DataLoader.py

Lines changed: 31 additions & 0 deletions

```python
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''

import torch.distributed as dist
from torch.utils.data import Dataset


class DatasetBase(Dataset):
    def __init__(self, slice_id=0, slice_count=1, use_dist=False, **kwargs):
        if use_dist:
            # in distributed mode, shard by rank and world size
            slice_id = dist.get_rank()
            slice_count = dist.get_world_size()
        self.id = slice_id
        self.count = slice_count

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        return 1000
```
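DatasetBase only records the shard id and count; nothing in this commit consumes them yet, but the usual pattern is for a subclass to keep every count-th path. A sketch of that convention (an assumption, not this repo's code):

```python
def shard(paths, slice_id, slice_count):
    # worker r of w keeps items r, r + w, r + 2w, ...
    return paths[slice_id::slice_count]

all_paths = [f'{i:06d}.png' for i in range(6)]
print(shard(all_paths, 0, 2))  # ['000000.png', '000002.png', '000004.png']
print(shard(all_paths, 1, 2))  # ['000001.png', '000003.png', '000005.png']
```

Each rank then sees a disjoint slice, and the slices together cover the whole dataset.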

data/TTNLoader.py

Lines changed: 84 additions & 0 deletions

```python
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
@author LeslieZhao
@date 20220721
'''

import os
import random

import numpy as np
import PIL.Image as Image
import torch
from torchvision import transforms

from data.DataLoader import DatasetBase


class TTNData(DatasetBase):
    def __init__(self, slice_id=0, slice_count=1, dist=False, **kwargs):
        super().__init__(slice_id, slice_count, dist, **kwargs)

        self.transform = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomResizedCrop(256, scale=(0.8, 1.2)),
            transforms.RandomRotation(degrees=(-90, 90)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        if kwargs['eval']:
            # no augmentation at eval time
            self.transform = transforms.Compose([
                transforms.Resize([256, 256]),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            self.length = 100

        src_root = kwargs['src_root']
        tgt_root = kwargs['tgt_root']

        self.src_paths = [os.path.join(src_root, f) for f in os.listdir(src_root) if f.endswith('.png')]
        self.tgt_paths = [os.path.join(tgt_root, f) for f in os.listdir(tgt_root) if f.endswith('.png')]
        self.src_length = len(self.src_paths)
        self.tgt_length = len(self.tgt_paths)
        random.shuffle(self.src_paths)
        random.shuffle(self.tgt_paths)

        # dataset-wide extrema of the eye and lip expression scores,
        # used for min-max normalization in __getitem__
        self.mx_left_eye_all, \
        self.mn_left_eye_all, \
        self.mx_right_eye_all, \
        self.mn_right_eye_all, \
        self.mx_lip_all, \
        self.mn_lip_all = \
            np.load(kwargs['score_info'])

    def __getitem__(self, i):
        src_idx = i % self.src_length
        tgt_idx = i % self.tgt_length

        src_path = self.src_paths[src_idx]
        tgt_path = self.tgt_paths[tgt_idx]
        # expression scores live beside the images: .../img/xxx.png -> .../express/xxx.npy
        exp_path = src_path.replace('img', 'express')[:-3] + 'npy'

        with Image.open(src_path) as img:
            srcImg = self.transform(img)

        with Image.open(tgt_path) as img:
            tgtImg = self.transform(img)

        # min-max normalize left-eye, right-eye and lip scores to [0, 1]
        score = np.load(exp_path)
        score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all)
        score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all)
        score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all)
        score = torch.from_numpy(score.astype(np.float32))

        return srcImg, tgtImg, score

    def __len__(self):
        if hasattr(self, 'length'):
            return self.length
        return 10000
```
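The score handling in __getitem__ is a per-component min-max rescale against dataset-wide extrema loaded from score_info; the same operation factored out:

```python
import numpy as np

def minmax_norm(v, mn, mx):
    # maps mn -> 0.0 and mx -> 1.0, linear in between
    return (v - mn) / (mx - mn)

s = minmax_norm(np.array([2.0, 5.0, 8.0]), 2.0, 8.0)
print(s)  # [0.  0.5 1. ]
```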

inference.py

Lines changed: 58 additions & 0 deletions

```python
import numpy as np
import cv2
import torch

from model.Pix2PixModule.model import Generator
from utils.utils import convert_img


class Infer:
    def __init__(self, model_path):
        self.net = Generator(img_channels=3)
        self.load_checkpoint(model_path)

    def run(self, img):
        if isinstance(img, str):
            img = cv2.imread(img)
        inp = self.preprocess(img)
        with torch.no_grad():
            xg = self.net(inp)
        return self.postprocess(xg[0])

    def load_checkpoint(self, path):
        ckpt = torch.load(path, map_location=lambda storage, loc: storage)
        self.net.load_state_dict(ckpt['netG'], strict=False)
        if torch.cuda.is_available():
            self.net.cuda()
        self.net.eval()

    def preprocess(self, img):
        # BGR -> RGB, scale to [-1, 1], HWC -> NCHW
        img = (img[..., ::-1] / 255.0 - 0.5) * 2
        img = img.transpose(2, 0, 1)[np.newaxis, :].astype(np.float32)
        img = torch.from_numpy(img)
        if torch.cuda.is_available():
            img = img.cuda()
        return img

    def postprocess(self, img):
        # back to uint8 HWC BGR for cv2
        img = convert_img(img, unit=True)
        return img.permute(1, 2, 0).cpu().numpy()[..., ::-1]


if __name__ == "__main__":
    path = 'pretrain_models/final.pth'
    model = Infer(path)

    img = cv2.imread('')  # put your own image path here

    # snap H and W down to multiples of 8 before feeding the generator
    img_h, img_w, _ = img.shape
    n_h, n_w = img_h // 8 * 8, img_w // 8 * 8
    img = cv2.resize(img, (n_w, n_h))

    oup = model.run(img)[..., ::-1]
    cv2.imwrite('output.png', oup)
```
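The __main__ block shrinks H and W to the nearest lower multiples of 8 before resizing, presumably so the generator's strided downsampling stages divide evenly; the same logic as a small helper:

```python
def snap_down(h, w, multiple=8):
    # largest (h2, w2) <= (h, w) with both divisible by `multiple`
    return h // multiple * multiple, w // multiple * multiple

print(snap_down(1080, 1919))  # (1080, 1912)
```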

model/Pix2PixModule/config.py

Lines changed: 28 additions & 0 deletions

```python
class Params:
    def __init__(self):

        self.name = 'Pix2Pix'

        self.pretrain_path = None
        self.vgg_model = 'pretrain_models/vgg19-dcbb9e9d.pth'
        self.lr = 2e-4
        self.beta1 = 0.5
        self.beta2 = 0.99

        # loss weights
        self.use_exp = True
        self.lambda_surface = 2.0
        self.lambda_texture = 2.0
        self.lambda_content = 200
        self.lambda_tv = 1e4
        self.lambda_exp = 1.0

        # data roots
        self.train_src_root = '/StyleTransform/DATA/ffhq-2w/img'
        self.train_tgt_root = '/StyleTransform/DATA/select-style-gan'
        self.val_src_root = '/StyleTransform/DATA/dmloghq-1k/img'
        self.val_tgt_root = '/StyleTransform/DATA/select-style-gan'
        self.score_info = 'pretrain_models/all_express_mean.npy'

        self.infer_batch_size = 2
```
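The lambda_* fields presumably weight the individual TTN loss terms in a single objective; a hypothetical sketch with placeholder loss values (the keys mirror the Params fields above, the raw numbers are made up):

```python
# hypothetical weighted sum of TTN loss terms; raw losses are placeholders
weights = {'surface': 2.0, 'texture': 2.0, 'content': 200, 'tv': 1e4, 'exp': 1.0}
raw = {'surface': 0.10, 'texture': 0.20, 'content': 0.01, 'tv': 1e-5, 'exp': 0.30}
total = sum(weights[k] * raw[k] for k in weights)
print(total)
```

The large lambda_tv and lambda_content balance against losses that are numerically tiny, so each term contributes on a comparable scale.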
