|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- encoding: utf-8 -*- |
| 3 | +''' |
| 4 | +@author LeslieZhao |
| 5 | +@date 20220721 |
| 6 | +''' |
| 7 | +import os |
| 8 | +from torchvision import transforms |
| 9 | +import PIL.Image as Image |
| 10 | +from data.DataLoader import DatasetBase |
| 11 | +import random |
| 12 | +import numpy as np |
| 13 | +import torch |
| 14 | + |
| 15 | + |
| 16 | +class TTNData(DatasetBase): |
| 17 | + def __init__(self, slice_id=0, slice_count=1,dist=False, **kwargs): |
| 18 | + super().__init__(slice_id, slice_count,dist, **kwargs) |
| 19 | + |
| 20 | + |
| 21 | + self.transform = transforms.Compose([ |
| 22 | + transforms.Resize([256,256]), |
| 23 | + transforms.RandomResizedCrop(256,scale=(0.8,1.2)), |
| 24 | + transforms.RandomRotation(degrees=(-90,90)), |
| 25 | + transforms.RandomHorizontalFlip(), |
| 26 | + transforms.ToTensor(), |
| 27 | + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
| 28 | + ]) |
| 29 | + |
| 30 | + if kwargs['eval']: |
| 31 | + self.transform = transforms.Compose([ |
| 32 | + transforms.Resize([256,256]), |
| 33 | + transforms.ToTensor(), |
| 34 | + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
| 35 | + self.length = 100 |
| 36 | + |
| 37 | + src_root = kwargs['src_root'] |
| 38 | + tgt_root = kwargs['tgt_root'] |
| 39 | + |
| 40 | + self.src_paths = [os.path.join(src_root,f) for f in os.listdir(src_root) if f.endswith('.png')] |
| 41 | + self.tgt_paths = [os.path.join(tgt_root,f) for f in os.listdir(tgt_root) if f.endswith('.png')] |
| 42 | + self.src_length = len(self.src_paths) |
| 43 | + self.tgt_length = len(self.tgt_paths) |
| 44 | + random.shuffle(self.src_paths) |
| 45 | + random.shuffle(self.tgt_paths) |
| 46 | + |
| 47 | + self.mx_left_eye_all,\ |
| 48 | + self.mn_left_eye_all,\ |
| 49 | + self.mx_right_eye_all,\ |
| 50 | + self.mn_right_eye_all,\ |
| 51 | + self.mx_lip_all,\ |
| 52 | + self.mn_lip_all = \ |
| 53 | + np.load(kwargs['score_info']) |
| 54 | + |
| 55 | + def __getitem__(self,i): |
| 56 | + src_idx = i % self.src_length |
| 57 | + tgt_idx = i % self.tgt_length |
| 58 | + |
| 59 | + src_path = self.src_paths[src_idx] |
| 60 | + tgt_path = self.tgt_paths[tgt_idx] |
| 61 | + exp_path = src_path.replace('img','express')[:-3] + 'npy' |
| 62 | + |
| 63 | + with Image.open(src_path) as img: |
| 64 | + srcImg = self.transform(img) |
| 65 | + |
| 66 | + with Image.open(tgt_path) as img: |
| 67 | + tgtImg = self.transform(img) |
| 68 | + |
| 69 | + score = np.load(exp_path) |
| 70 | + score[0] = (score[0] - self.mn_left_eye_all) / (self.mx_left_eye_all - self.mn_left_eye_all) |
| 71 | + score[1] = (score[1] - self.mn_right_eye_all) / (self.mx_right_eye_all - self.mn_right_eye_all) |
| 72 | + score[2] = (score[2] - self.mn_lip_all) / (self.mx_lip_all - self.mn_lip_all) |
| 73 | + score = torch.from_numpy(score.astype(np.float32)) |
| 74 | + |
| 75 | + return srcImg,tgtImg,score |
| 76 | + |
| 77 | + |
| 78 | + def __len__(self): |
| 79 | + # return max(self.src_length,self.tgt_length) |
| 80 | + if hasattr(self,'length'): |
| 81 | + return self.length |
| 82 | + else: |
| 83 | + return 10000 |
| 84 | + |
0 commit comments