diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..46c31ca
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+download/
+output_logfile.txt
diff --git a/README.md b/README.md
index 036b12a..7933970 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,20 @@
-# HDTF
 Flow-guided One-shot Talking Face Generation with a High-resolution Audio-visual Dataset
-paper supplementary
+paper supplementary [demo video](https://www.youtube.com/watch?v=uJdBgWYBTww)
 
 ## Details of HDTF dataset
 **./HDTF_dataset** consists of the *youtube video url*, the *video resolution* (the resolution used in our method, which may not be the best available one), the *time stamps of the talking face*, the *facial region* (used in our method) and *the zoom scale* of the cropped window.
 
 **xx_video_url.txt:**
 
 ```
 format: video name | video youtube url
@@ -44,7 +52,20 @@ When using HDTF dataset,
 
 - We resize all cropped videos into **512 x 512** resolution.
 
-The HDTF dataset is available to download under a Creative Commons Attribution 4.0 International License. If you face any problems when processing HDTF, pls contact me.
+The HDTF dataset is available to download under a Creative Commons Attribution 4.0 International License. **Thanks to @universome for providing the data-processing script; please visit [here](https://github.com/universome/HDTF) for more details.** If you face any problems when processing HDTF, please contact me.
+
+## Inference code
+#### code of audio-to-animation
+coming soon...
+
+#### code of constructing approximate dense flow
+The code is in **./code_constructing_Fapp**; please visit [here](https://github.com/MRzzm/HDTF/tree/main/code_constructing_Fapp) for more details.
+
+#### code of animation-to-video module
+The code is in **./code_animation2video**; please visit [here](https://github.com/MRzzm/HDTF/tree/main/code_animation2video) for more details.
+
+#### code of reproducing other works
+coming soon...
 
 ## Downloading
 For convenience, we added the `download.py` script which downloads, crops and resizes the dataset.
You can use it via the following command:
diff --git a/code_animation2video/inference.py b/code_animation2video/inference.py
new file mode 100644
index 0000000..0744d9f
--- /dev/null
+++ b/code_animation2video/inference.py
@@ -0,0 +1,112 @@
+import os
+import subprocess
+
+import cv2
+import numpy as np
+import torch
+
+from models import VideoGenerator
+from opts import parse_opts
+from utils import visualize_dense_flow, make_coordinate_grid
+
+if __name__ == "__main__":
+    opt = parse_opts()
+    ### load the trained model
+    video_generator = VideoGenerator(opt.input_channel, opt.encoder_num_down_blocks, opt.encoder_block_expansion,
+                                     opt.encoder_max_features, opt.houglass_num_blocks,
+                                     opt.houglass_block_expansion, opt.houglass_max_features,
+                                     opt.num_bottleneck_blocks).cuda()
+    old_dict = torch.load(opt.model_path)['state_dict']['net_g']
+    new_dict = {}
+    for k, v in old_dict.items():
+        # strip the 'module.' prefix added by DataParallel and map the old
+        # layer names onto the current module names
+        name = k[7:].replace('dense_motion', 'foreground_matting').replace('attention_mask', 'matting_mask')
+        new_dict[name] = v
+    video_generator.load_state_dict(new_dict)
+    video_generator.eval()
+    #### load the reference image
+    reference_image = cv2.imread(opt.image_path)
+    reference_tensor = torch.from_numpy(reference_image / 255).permute(2, 0, 1).float().unsqueeze(0).cuda()
+    #### load the approximate dense flow
+    Fapp = np.load(opt.dense_flow_path)
+    #### output setting
+    ## generated video
+    synthetic_video_path = os.path.join(opt.res_path, os.path.basename(opt.image_path)[:-4] + '_video.mp4')
+    if os.path.exists(synthetic_video_path):
+        os.remove(synthetic_video_path)
+    videowriter_synthetic_video = cv2.VideoWriter(synthetic_video_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (512, 512))
+    ## approximate dense flow
+    Fapp_video_path = synthetic_video_path.replace('video.mp4', 'Fapp.mp4')
+    if os.path.exists(Fapp_video_path):
+        os.remove(Fapp_video_path)
+    videowriter_Fapp = cv2.VideoWriter(Fapp_video_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (512, 512))
+    ## matting mask
+    matting_mask_path = synthetic_video_path.replace('video.mp4', 'matting_mask.mp4')
+    if os.path.exists(matting_mask_path):
+        os.remove(matting_mask_path)
+    videowriter_matting_mask = cv2.VideoWriter(matting_mask_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+    ## revised dense flow
+    revised_dense_path = synthetic_video_path.replace('video.mp4', 'revised_dense.mp4')
+    if os.path.exists(revised_dense_path):
+        os.remove(revised_dense_path)
+    videowriter_revised_dense = cv2.VideoWriter(revised_dense_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+    ## foreground mask
+    foreground_mask_path = synthetic_video_path.replace('video.mp4', 'foreground_mask.mp4')
+    if os.path.exists(foreground_mask_path):
+        os.remove(foreground_mask_path)
+    videowriter_foreground_mask = cv2.VideoWriter(foreground_mask_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+    ## warped image
+    warped_image_path = synthetic_video_path.replace('video.mp4', 'warped_image.mp4')
+    if os.path.exists(warped_image_path):
+        os.remove(warped_image_path)
+    videowriter_warped_image = cv2.VideoWriter(warped_image_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+    ######### generate the video frame by frame
+    frame_length = Fapp.shape[0]
+    for i in range(frame_length):
+        print('generating frame {}/{}'.format(i, frame_length))
+        Fapp_i = Fapp[i, :, :]
+        Fapp_i_visual = visualize_dense_flow(Fapp_i - make_coordinate_grid((Fapp_i.shape[0], Fapp_i.shape[1])))
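+        # Fapp stores absolute sampling positions in the normalized [-1, 1]
+        # grid_sample convention; subtracting the identity grid above turns it
+        # into a relative displacement field before the HSV flow visualization.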
+        videowriter_Fapp.write(Fapp_i_visual)
+        with torch.no_grad():
+            Fapp_i = torch.from_numpy(Fapp_i).float().cuda().unsqueeze(0)
+            res_out = video_generator(reference_tensor, Fapp_i)
+        ## synthetic video
+        synthetic_video_i = res_out['synthetic_image'] * 255
+        synthetic_video_i = synthetic_video_i.cpu().squeeze().permute(1, 2, 0).float().detach().numpy().astype(np.uint8)
+        videowriter_synthetic_video.write(synthetic_video_i)
+        ## warped image
+        warped_image_i = res_out['warped_image'] * 255
+        warped_image_i = warped_image_i.cpu().squeeze().permute(1, 2, 0).float().detach().numpy().astype(np.uint8)
+        videowriter_warped_image.write(warped_image_i)
+        ## foreground mask
+        foreground_mask_i = res_out['foreground_mask'] * 255
+        foreground_mask_i = np.expand_dims(foreground_mask_i.cpu().squeeze().detach().numpy().astype(np.uint8), 2)
+        foreground_mask_i = foreground_mask_i.repeat(3, 2)
+        videowriter_foreground_mask.write(foreground_mask_i)
+        ## visualization of the revised (foreground-only) dense flow
+        dense_flow_foreground_vis_i = foreground_mask_i / 255.0 * cv2.resize(Fapp_i_visual, (128, 128))
+        dense_flow_foreground_vis_i = dense_flow_foreground_vis_i.astype(np.uint8)
+        videowriter_revised_dense.write(dense_flow_foreground_vis_i)
+        ## matting mask
+        matting_mask_i = res_out['matting_mask'] * 255
+        matting_mask_i = matting_mask_i.cpu().squeeze().detach().numpy().astype(np.uint8)
+        matting_mask_i = np.expand_dims(matting_mask_i, 2).repeat(3, 2)
+        videowriter_matting_mask.write(matting_mask_i)
+
+    videowriter_Fapp.release()
+    videowriter_synthetic_video.release()
+    videowriter_warped_image.release()
+    videowriter_foreground_mask.release()
+    videowriter_revised_dense.release()
+    videowriter_matting_mask.release()
+
+    ## optionally mux the input audio back into the generated video with ffmpeg
+    if os.path.exists(opt.audio_path):
+        video_add_audio_path = synthetic_video_path.replace('.mp4', '_add_audio.mp4')
+        if os.path.exists(video_add_audio_path):
+            os.remove(video_add_audio_path)
+        cmd = 'ffmpeg -i {} -i {} -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 {}'.format(
+            synthetic_video_path,
+            opt.audio_path,
+            video_add_audio_path)
+        subprocess.call(cmd, shell=True)
diff --git a/code_animation2video/models.py b/code_animation2video/models.py
new file mode 100644
index 0000000..014680c
--- /dev/null
+++ b/code_animation2video/models.py
@@ -0,0 +1,292 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+
+class ResBlock2d(nn.Module):
+    def __init__(self, in_features, kernel_size, padding):
+        super(ResBlock2d, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.norm1 = BatchNorm2d(in_features)
+        self.norm2 = BatchNorm2d(in_features)
+        self.relu = nn.ReLU()
+    def forward(self, x):
+        out = self.norm1(x)
+        out = self.relu(out)
+        out = self.conv1(out)
+        out = self.norm2(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out += x
+        return out
+
+class UpBlock2d(nn.Module):
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(UpBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features)
+        self.relu = nn.ReLU()
+    def forward(self, x):
+        out = F.interpolate(x, scale_factor=2)
+        out = self.conv(out)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+class DownBlock2d(nn.Module):
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(DownBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features)
+        self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+        self.relu = nn.ReLU()
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.norm(out)
+        out = self.relu(out)
+        out = self.pool(out)
+        return out
+
+class SameBlock2d(nn.Module):
+    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
+        super(SameBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+                              kernel_size=kernel_size, padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features)
+        self.relu = nn.ReLU()
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+class HourglassEncoder(nn.Module):
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(HourglassEncoder, self).__init__()
+        down_blocks = []
+        for i in range(num_blocks):
+            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+                                           min(max_features, block_expansion * (2 ** (i + 1))),
+                                           kernel_size=3, padding=1))
+        self.down_blocks = nn.ModuleList(down_blocks)
+
+    def forward(self, x):
+        outs = [x]
+        for down_block in self.down_blocks:
+            outs.append(down_block(outs[-1]))
+        return outs
+
+class HourglassDecoder(nn.Module):
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(HourglassDecoder, self).__init__()
+        up_blocks = []
+        for i in range(num_blocks)[::-1]:
+            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+            out_filters = min(max_features, block_expansion * (2 ** i))
+            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
+        self.up_blocks = nn.ModuleList(up_blocks)
+        self.out_filters = block_expansion + in_features
+    def forward(self, x):
+        out = x.pop()
+        for up_block in self.up_blocks:
+            out = up_block(out)
+            skip = x.pop()
+            out = torch.cat([out, skip], dim=1)
+        return out
+
+class Hourglass(nn.Module):
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(Hourglass, self).__init__()
+        self.encoder = HourglassEncoder(block_expansion, in_features, num_blocks, max_features)
+        self.decoder = HourglassDecoder(block_expansion, in_features, num_blocks, max_features)
+        self.out_filters = self.decoder.out_filters
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+class AntiAliasInterpolation2d(nn.Module):
+    """
+    Band-limited downsampling, for better preservation of the input signal.
+    """
+    def __init__(self, channels, scale):
+        super(AntiAliasInterpolation2d, self).__init__()
+        sigma = (1 / scale - 1) / 2
+        kernel_size = 2 * round(sigma * 4) + 1
+        self.ka = kernel_size // 2
+        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+        kernel_size = [kernel_size, kernel_size]
+        sigma = [sigma, sigma]
+        # The gaussian kernel is the product of the
+        # gaussian function of each dimension.
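+        # With the default opts (encoder_num_down_blocks = 2) the generator
+        # passes scale = 1/4 here, so sigma = 1.5 and kernel_size = 13: a
+        # 13x13 depthwise Gaussian low-pass filter applied before sub-sampling.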
+ kernel = 1 + meshgrids = torch.meshgrid( + [ + torch.arange(size, dtype=torch.float32) + for size in kernel_size + ] + ) + for size, std, mgrid in zip(kernel_size, sigma, meshgrids): + mean = (size - 1) / 2 + kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2)) + + # Make sure sum of values in gaussian kernel equals 1. + kernel = kernel / torch.sum(kernel) + # Reshape to depthwise convolutional weight + kernel = kernel.view(1, 1, *kernel.size()) + kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) + + self.register_buffer('weight', kernel) + self.groups = channels + self.scale = scale + + def forward(self, input): + if self.scale == 1.0: + return input + + out = F.pad(input, (self.ka, self.kb, self.ka, self.kb)) + out = F.conv2d(out, weight=self.weight, groups=self.groups) + out = F.interpolate(out, scale_factor=(self.scale, self.scale)) + + return out + +class Encoder(nn.Module): + def __init__(self, num_channels, num_down_blocks=3, block_expansion=64, max_features=512, + ): + super(Encoder, self).__init__() + self.in_conv = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3)) + down_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** i)) + out_features = min(max_features, block_expansion * (2 ** (i + 1))) + down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.down_blocks = nn.Sequential(*down_blocks) + def forward(self, image): + out = self.in_conv(image) + out = self.down_blocks(out) + return out + +class Bottleneck(nn.Module): + def __init__(self, num_bottleneck_blocks,num_down_blocks=3, block_expansion=64, max_features=512): + super(Bottleneck, self).__init__() + bottleneck = [] + in_features = min(max_features, block_expansion * (2 ** num_down_blocks)) + for i in range(num_bottleneck_blocks): + bottleneck.append(ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) + self.bottleneck = nn.Sequential(*bottleneck) + def forward(self, feature_map): + out = self.bottleneck(feature_map) + return out + +class Decoder(nn.Module): + def __init__(self,num_channels, num_down_blocks=3, block_expansion=64, max_features=512): + super(Decoder, self).__init__() + up_blocks = [] + for i in range(num_down_blocks): + in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i))) + out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1))) + up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1))) + self.up_blocks = nn.Sequential(*up_blocks) + self.out_conv = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3)) + self.sigmoid = nn.Sigmoid() + def forward(self, feature_map): + out = self.up_blocks(feature_map) + out = self.out_conv(out) + out = self.sigmoid(out) + return out + +def warp_image(image, motion_flow): + _, h_old, w_old, _ = motion_flow.shape + _, _, h, w = image.shape + if h_old != h or w_old != w: + motion_flow = motion_flow.permute(0, 3, 1, 2) + motion_flow = F.interpolate(motion_flow, size=(h, w), mode='bilinear') + motion_flow = motion_flow.permute(0, 2, 3, 1) + return F.grid_sample(image, motion_flow) + +def make_coordinate_grid(spatial_size, type): + h, w = spatial_size + x = torch.arange(w).type(type) + y = torch.arange(h).type(type) + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + yy = y.view(-1, 1).repeat(1, w) + xx = x.view(1, -1).repeat(h, 1) + meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2) + return meshed + 
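+# A small, hypothetical sanity check (not part of the original repo): the
+# flows used throughout this file hold absolute F.grid_sample positions
+# normalized to [-1, 1] in (x, y) order, so warping with the identity grid
+# from make_coordinate_grid should reproduce the input. Exact agreement
+# assumes the align_corners=True semantics of the pinned torch==1.0.0.
+def _identity_warp_sanity_check():
+    image = torch.rand(1, 3, 64, 64)
+    grid = make_coordinate_grid((64, 64), image.type()).unsqueeze(0)  # 1 x 64 x 64 x 2
+    assert torch.allclose(warp_image(image, grid), image, atol=1e-5)
+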
+class ForegroundMatting(nn.Module):
+    def __init__(self, num_channels, scale_factor, matting_channel, num_blocks, block_expansion, max_features):
+        super(ForegroundMatting, self).__init__()
+        self.down_sample_image = AntiAliasInterpolation2d(num_channels, scale_factor)
+        self.down_sample_flow = AntiAliasInterpolation2d(2, scale_factor)
+        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=num_channels * 2 + 2,
+                                   max_features=max_features, num_blocks=num_blocks)
+        self.foreground_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+        self.matting_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+        self.matting = nn.Conv2d(self.hourglass.out_filters, matting_channel, kernel_size=(7, 7), padding=(3, 3))
+        self.scale_factor = scale_factor
+        self.sigmoid = nn.Sigmoid()
+    def forward(self, reference_image, dense_flow):
+        '''
+        reference_image: b x c x h x w
+        dense_flow: b x h x w x 2
+        '''
+        res_out = {}
+        if self.scale_factor != 1:  # down-sample the image and flow to the feature-map resolution
+            reference_image = self.down_sample_image(reference_image)
+            dense_flow = self.down_sample_flow(dense_flow.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+        batch, _, h, w = reference_image.shape
+        warped_image = warp_image(reference_image, dense_flow)  # warp the image with the dense flow
+        res_out['warped_image'] = warped_image
+        hourglass_input = torch.cat([reference_image, dense_flow.permute(0, 3, 1, 2), warped_image], dim=1)
+        hourglass_out = self.hourglass(hourglass_input)
+        foreground_mask = self.foreground_mask(hourglass_out)  # compute the foreground mask
+        foreground_mask = self.sigmoid(foreground_mask).permute(0, 2, 3, 1)
+        res_out['foreground_mask'] = foreground_mask
+        grid_flow = make_coordinate_grid((h, w), dense_flow.type())
+        ## revise the dense flow: keep it in the foreground, fall back to the identity grid elsewhere
+        dense_flow_foreground = dense_flow * foreground_mask + (1 - foreground_mask) * grid_flow.unsqueeze(0)
+        res_out['dense_flow_foreground'] = dense_flow_foreground
+        res_out['dense_flow_foreground_vis'] = dense_flow * foreground_mask
+        matting_mask = self.matting_mask(hourglass_out)  # compute the matting mask
+        matting_mask = self.sigmoid(matting_mask)
+        res_out['matting_mask'] = matting_mask
+        matting_image = self.matting(hourglass_out)  # compute the matting image
+        res_out['matting_image'] = matting_image
+        return res_out
+
+
+class VideoGenerator(nn.Module):
+    def __init__(self, num_channels, encoder_num_down_blocks=3, encoder_block_expansion=64,
+                 encoder_max_features=512, houglass_num_blocks=5,
+                 houglass_block_expansion=64, houglass_max_features=1024, num_bottleneck_blocks=6):
+        super(VideoGenerator, self).__init__()
+        self.encoder = Encoder(num_channels, encoder_num_down_blocks,
+                               encoder_block_expansion, encoder_max_features)
+        matting_channel = int(min(encoder_max_features, encoder_block_expansion * (2 ** encoder_num_down_blocks)))
+        self.foreground_matting = ForegroundMatting(num_channels, scale_factor=1 / (2 ** encoder_num_down_blocks), matting_channel=matting_channel,
+                                                    num_blocks=houglass_num_blocks, block_expansion=houglass_block_expansion,
+                                                    max_features=houglass_max_features)
+        self.bottleneck = Bottleneck(num_bottleneck_blocks, encoder_num_down_blocks,
+                                     encoder_block_expansion, encoder_max_features)
+        self.decoder = Decoder(num_channels, encoder_num_down_blocks, encoder_block_expansion,
+                               encoder_max_features)
+    def forward(self, reference_image, dense_flow):
+        '''
+        reference_image: b x c x h x w
+        dense_flow: b x h x w x 2
+        '''
+        feature_map = self.encoder(reference_image)  ## compute the feature map
+        res_out = 
self.foreground_matting(reference_image, dense_flow) ## compute matting & revise dense flow + assert feature_map.shape[2] == res_out['matting_mask'].shape[2] and feature_map.shape[3] == res_out['matting_mask'].shape[3] + warped_feature_map = warp_image(feature_map, res_out['dense_flow_foreground']) * res_out['matting_mask'] + (1-res_out['matting_mask']) * res_out['matting_image'] + warped_feature_map = self.bottleneck(warped_feature_map) # decode feature map + synthetic_image = self.decoder(warped_feature_map) # decode feature map + res_out['synthetic_image'] = synthetic_image + return res_out \ No newline at end of file diff --git a/code_animation2video/opts.py b/code_animation2video/opts.py new file mode 100644 index 0000000..68a5fa0 --- /dev/null +++ b/code_animation2video/opts.py @@ -0,0 +1,24 @@ +import argparse +def parse_opts(): + parser = argparse.ArgumentParser(description='animation2video') + # ========================= Input Configs ========================== + parser.add_argument('--model_path', type=str, default=r'./checkpoints/checkpoint_animation2video.pth', help='trained model path') + parser.add_argument('--image_path', type=str, default=r'./test_data/taile.jpg', help='reference image path') + parser.add_argument('--dense_flow_path', type=str, default=r'./test_data/taile_Fapp.npy', help='reference approximate dense flow path') + parser.add_argument('--audio_path', type=str, default=r'./test_data/chuanpu.wav', help='input audio path') + parser.add_argument('--res_path', type=str, default=r'./result', help='result path') + # ========================= Base Configs ========================== + parser.add_argument('--input_channel', type=int, default=3, help='input image channels') + parser.add_argument('--out_channel', type=int, default=3, help='output image channels') + parser.add_argument('--image_size', type=int, default=512, help='image size') + #========================= Network Configs ========================== + parser.add_argument('--encoder_num_down_blocks', type=int, default=2, help='network setting') + parser.add_argument('--encoder_block_expansion', type=int, default=64, help='network setting') + parser.add_argument('--encoder_max_features', type=int, default=512, help='network setting') + parser.add_argument('--num_bottleneck_blocks', type=int, default=2, help='network setting') + parser.add_argument('--houglass_num_blocks', type=int, default=5, help='network setting') + parser.add_argument('--houglass_block_expansion', type=int, default=64, help='network setting') + parser.add_argument('--houglass_max_features', type=int, default=512, help='network setting') + args = parser.parse_args() + + return args \ No newline at end of file diff --git a/code_animation2video/readme.md b/code_animation2video/readme.md new file mode 100644 index 0000000..b5d748a --- /dev/null +++ b/code_animation2video/readme.md @@ -0,0 +1,13 @@ +# Code of animation-to-video module +### inference + + 1. Download the trained model (`checkpoint_animation2video.pth`), approximate dense flow (`mengnalisa_Fapp.npy, taile_Fapp.npy`) in [google drive](https://drive.google.com/drive/folders/1OM3AE6rjZKY1v6PVDnv-YwlmkBZOhw1L?usp=sharing). + 2. Put the `checkpoint_animation2video.pth` into **./checkpoints** + 3. Put the `mengnalisa_Fapp.npy, taile_Fapp.npy` into **./test_data** + 4. 
run
+> python inference.py --image_path=./test_data/mengnalisa.jpg --dense_flow_path=./test_data/mengnalisa_Fapp.npy
+
+or
+> python inference.py --image_path=./test_data/taile.jpg --dense_flow_path=./test_data/taile_Fapp.npy
+
+to generate all intermediate results of the animation-to-video module.
\ No newline at end of file
diff --git a/code_animation2video/requirements.txt b/code_animation2video/requirements.txt
new file mode 100644
index 0000000..3065e88
--- /dev/null
+++ b/code_animation2video/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.0.0
+torchvision==0.2.2
+opencv_python==4.4.0.46
+numpy==1.21.2
diff --git a/code_animation2video/sync_batchnorm/__init__.py b/code_animation2video/sync_batchnorm/__init__.py
new file mode 100644
index 0000000..bc8709d
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File   : __init__.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..3104579
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
new file mode 100644
index 0000000..1590c73
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc
new file mode 100644
index 0000000..c7b8551
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
new file mode 100644
index 0000000..4db52bc
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/batchnorm.py b/code_animation2video/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000..5f4e763
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File   : batchnorm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+ +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. + if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + + # Always using same "device order" makes the ReduceAdd operation faster. 
+ # Thanks to:: Tete Xiao (http://tetexiao.com/) + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data + self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. 
+ + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/code_animation2video/sync_batchnorm/comm.py b/code_animation2video/sync_batchnorm/comm.py new file mode 100644 index 0000000..922f8c4 --- /dev/null +++ b/code_animation2video/sync_batchnorm/comm.py @@ -0,0 +1,137 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' 
+ self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def __getstate__(self): + return {'master_callback': self._master_callback} + + def __setstate__(self, state): + self.__init__(state['master_callback']) + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' + self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' 
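+        # hand each slave its share of the reduced statistics through its future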
+ + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/code_animation2video/sync_batchnorm/replicate.py b/code_animation2video/sync_batchnorm/replicate.py new file mode 100644 index 0000000..b71c7b8 --- /dev/null +++ b/code_animation2video/sync_batchnorm/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. 
+ + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/code_animation2video/sync_batchnorm/unittest.py b/code_animation2video/sync_batchnorm/unittest.py new file mode 100644 index 0000000..0675c02 --- /dev/null +++ b/code_animation2video/sync_batchnorm/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/code_animation2video/test_data/chuanpu.wav b/code_animation2video/test_data/chuanpu.wav new file mode 100644 index 0000000..60fab9c Binary files /dev/null and b/code_animation2video/test_data/chuanpu.wav differ diff --git a/code_animation2video/test_data/mengnalisa.jpg b/code_animation2video/test_data/mengnalisa.jpg new file mode 100644 index 0000000..622e18e Binary files /dev/null and b/code_animation2video/test_data/mengnalisa.jpg differ diff --git a/code_animation2video/test_data/taile.jpg b/code_animation2video/test_data/taile.jpg new file mode 100644 index 0000000..23f77ec Binary files /dev/null and b/code_animation2video/test_data/taile.jpg differ diff --git a/code_animation2video/utils.py b/code_animation2video/utils.py new file mode 100644 index 0000000..5caf246 --- /dev/null +++ b/code_animation2video/utils.py @@ -0,0 +1,24 @@ +import numpy as np +import cv2 + +def make_coordinate_grid(image_size): + h, w = image_size + x = np.arange(w) + y = np.arange(h) + x = (2 * (x / (w - 1)) - 1) + y = (2 * (y / (h - 1)) - 1) + xx = x.reshape(1, -1).repeat(h, axis=0) + yy = y.reshape(-1, 1).repeat(w, axis=1) + meshed = np.stack([xx, yy], 2) + return meshed + +def visualize_dense_flow(dense_flow): + hsv = np.zeros([dense_flow.shape[0], dense_flow.shape[1]]) + hsv = np.stack([hsv, hsv, hsv], 2).astype(np.uint8) + hsv[..., 1] = 255 + mag, ang = cv2.cartToPolar(dense_flow[:, :, 0], dense_flow[:, :, 1]) + hsv[..., 0] = ang * 180 / np.pi / 2 + hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) + rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) + rgb = rgb.astype(np.uint8) + return rgb \ No newline at end of file diff --git a/code_constructing_Fapp/inference.py b/code_constructing_Fapp/inference.py new file mode 100644 index 0000000..550a3e4 --- /dev/null +++ b/code_constructing_Fapp/inference.py @@ -0,0 +1,132 @@ +import 
numpy as np
+from scipy.interpolate import griddata
+import argparse
+import os
+
+def to_image(projected_points, h, w):
+    '''
+    transform the center to (0,0)
+    '''
+    image_vertices = projected_points.copy()
+    image_vertices[:, 0] = image_vertices[:, 0] + w / 2
+    image_vertices[:, 1] = image_vertices[:, 1] + h / 2
+    image_vertices[:, 1] = h - image_vertices[:, 1] - 1
+    return image_vertices
+
+def orthogonal_transform(points3D, scale, R, t):
+    '''
+    orthogonal transform
+    '''
+    t3d = np.squeeze(np.array(t, dtype=np.float32))
+    transformed_vertices = scale * points3D.dot(R.T) + t3d[np.newaxis, :]
+
+    return transformed_vertices
+
+def project_to_image(points3D, scale, R, t, h, w):
+    '''
+    project 3D points to the 2D plane with an orthogonal projection
+    '''
+    projected_points = orthogonal_transform(points3D, scale, R, t)
+    projected_points = to_image(projected_points, h, w)
+    return projected_points
+
+def compute_projected_mesh_points(shape_para, exp_para, R, T, scale, image_size, model_3dmm):
+    '''
+    compute the projected mesh points from facial animation parameters
+    :param shape_para: the shape parameter of the reference face
+    :param exp_para: the expression parameter
+    :param R: the head rotation in the orthogonal projection
+    :param T: the head translation in the orthogonal projection
+    :param scale: the scale parameter in the orthogonal projection
+    :param image_size: the size of the image
+    :param model_3dmm: the 3DMM model
+    :return: projected_mesh_points
+    '''
+    shape_para = np.expand_dims(shape_para, 1)
+    exp_para = np.expand_dims(exp_para, 1)
+    R_matrix = R.reshape((3, 3))
+    ## compute the 3D mesh points with the 3DMM
+    mesh_points_3D = model_3dmm.generate_vertices(shape_para, exp_para)
+    ## project the 3D points to the 2D plane
+    projected_2Dpoints = project_to_image(mesh_points_3D, scale, R_matrix, T, image_size[0], image_size[1])
+    return projected_2Dpoints
+
+
+def make_coordinate_grid(image_size):
+    h, w = image_size
+    x = np.arange(w)
+    y = np.arange(h)
+    x = (2 * (x / (w - 1)) - 1)
+    y = (2 * (y / (h - 1)) - 1)
+    xx = x.reshape(1, -1).repeat(h, axis=0)
+    yy = y.reshape(-1, 1).repeat(w, axis=1)
+    meshed = np.stack([xx, yy], 2)
+    return meshed
+
+def construct_Fapp(reference_projected_mesh_points,
+                   drive_projected_mesh_points, image_size):
+    '''
+    compute Fapp from projected mesh points
+    reference_projected_mesh_points: the projected mesh points of the reference image
+    drive_projected_mesh_points: the driving projected mesh points
+    '''
+    ## rescale to -1 ~ 1
+    reference_projected_mesh_points = (reference_projected_mesh_points / image_size * 2) - 1
+    ### compute the max height of the face (in pixels, before rescaling)
+    face_max_h = np.max(drive_projected_mesh_points[:, 1]).astype(int)
+    ## rescale to -1 ~ 1
+    drive_projected_mesh_points = (drive_projected_mesh_points / image_size * 2) - 1
+    drive_projected_mesh_points_yx = drive_projected_mesh_points[:, [1, 0]]
+    ## compute the sparse flow between the reference and driving mesh points
+    sparse_dense_flow = reference_projected_mesh_points - drive_projected_mesh_points
+    ## compute the average head motion
+    mean_dense_flow = np.mean(sparse_dense_flow, axis=0)
+    ## interpolate the dense flow in the head-related region
+    grid_nums = complex(str(image_size) + "j")
+    grid_y, grid_x = np.mgrid[-1:1:grid_nums, -1:1:grid_nums]
+    dense_foreground_flow_x = griddata(drive_projected_mesh_points_yx, sparse_dense_flow[:, 0], (grid_y, grid_x), method='nearest')
+    ## fill the torso-related region with the average head motion
+    dense_foreground_flow_x[face_max_h:, :] = mean_dense_flow[0]
+    dense_foreground_flow_y = griddata(drive_projected_mesh_points_yx, sparse_dense_flow[:, 1], (grid_y, grid_x), method='nearest')
+    dense_foreground_flow_y[face_max_h:, :] = mean_dense_flow[1]
+    Fapp = np.stack([dense_foreground_flow_x, dense_foreground_flow_y], 2)
+    ## add the identity grid so Fapp stores absolute sampling positions
+    grid_mesh = make_coordinate_grid((image_size, image_size))
+    Fapp = grid_mesh + Fapp
+
+    return Fapp
+
+def parse_opts():
+    parser = argparse.ArgumentParser(description='construct Fapp')
+    parser.add_argument('--reference_projected_mesh_points_path', type=str,
+                        default='./test_data/taile_source_points.npy',
+                        help='the projected mesh points of the reference image')
+    parser.add_argument('--drive_projected_mesh_points_path', type=str,
+                        default='./test_data/taile_drive_points.npy',
+                        help='the driving projected mesh points')
+    parser.add_argument('--image_size', type=int, default=512, help='the size of the image')
+    parser.add_argument('--res_dir', type=str, default='./result', help='the dir of the results')
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    '''
+    We are not allowed to share the 3DMM model, so we release only the inference
+    code that constructs Fapp from projected mesh points. The function
+    "compute_projected_mesh_points" shows how to compute the projected mesh
+    points from facial animation parameters.
+    '''
+    opt = parse_opts()
+    reference_projected_mesh_points = np.load(opt.reference_projected_mesh_points_path)
+    drive_projected_mesh_points = np.load(opt.drive_projected_mesh_points_path)
+    frame_num = drive_projected_mesh_points.shape[0]
+    res_Fapp = []
+    for i in range(frame_num):
+        print('construct {}/{} Fapp'.format(i, frame_num))
+        Fapp_i = construct_Fapp(reference_projected_mesh_points,
+                                drive_projected_mesh_points[i, :, :], opt.image_size)
+        res_Fapp.append(Fapp_i)
+    res_Fapp = np.stack(res_Fapp, 0)
+    res_Fapp_path = os.path.join(opt.res_dir, os.path.basename(opt.reference_projected_mesh_points_path).replace('_source_points', '_Fapp'))
+    np.save(res_Fapp_path, res_Fapp)
\ No newline at end of file
diff --git a/code_constructing_Fapp/readme.md b/code_constructing_Fapp/readme.md
new file mode 100644
index 0000000..33c79d4
--- /dev/null
+++ b/code_constructing_Fapp/readme.md
@@ -0,0 +1,12 @@
+# Code of constructing Fapp
+### inference
+
+ 1. Download the projected facial points (`mengnalisa_source_points.npy`, `mengnalisa_drive_points.npy`, `taile_source_points.npy`, `taile_drive_points.npy`) from [google drive](https://drive.google.com/drive/folders/1OM3AE6rjZKY1v6PVDnv-YwlmkBZOhw1L?usp=sharing).
+ 2. Put all files into **./test_data**
+ 3. run
+> python inference.py --reference_projected_mesh_points_path=./test_data/taile_source_points.npy --drive_projected_mesh_points_path=./test_data/taile_drive_points.npy
+
+or
+> python inference.py --reference_projected_mesh_points_path=./test_data/mengnalisa_source_points.npy --drive_projected_mesh_points_path=./test_data/mengnalisa_drive_points.npy
+
+to compute Fapp.
\ No newline at end of file
diff --git a/code_constructing_Fapp/requirements.txt b/code_constructing_Fapp/requirements.txt
new file mode 100644
index 0000000..7293049
--- /dev/null
+++ b/code_constructing_Fapp/requirements.txt
@@ -0,0 +1,2 @@
+scipy==1.3.1
+numpy==1.21.2
diff --git a/download.py b/download.py
index 3e21b61..8d80d32 100644
--- a/download.py
+++ b/download.py
@@ -10,25 +10,62 @@
 $ python download.py --output_dir /tmp/data/hdtf --num_workers 8
 ```
 
-You need tqdm and youtube-dl libraries to be installed for this script to work.
+You need the tqdm, yt_dlp, and colorama libraries to be installed for this script to work.
""" - import os import argparse +import subprocess +import pprint from typing import List, Dict from multiprocessing import Pool -import subprocess -from subprocess import Popen, PIPE from urllib import parse +import yt_dlp from tqdm import tqdm - +from colorama import init as cinit +from colorama import Fore subsets = ["RD", "WDA", "WRA"] +cinit(autoreset=True) def download_hdtf(source_dir: os.PathLike, output_dir: os.PathLike, num_workers: int, **process_video_kwargs): + """ + Downloads and processes videos from the HDTF dataset in parallel using multiprocessing. + + The function manages the download process by: + - Creating the necessary output directories. + - Constructing a download queue from files in the specified source directory. + - Using a multiprocessing pool to handle downloads and subsequent processing. + - Providing progress tracking with tqdm. + + After completing the download, a message is displayed with optional cleanup instructions to delete + temporary raw video files to save space. + + Args: + source_dir (os.PathLike): The directory containing HDTF metadata files, including video URLs, + crop data, time intervals, and resolution information for each video subset. + output_dir (os.PathLike): The directory where downloaded videos and processed files will be saved. + num_workers (int): The number of parallel worker processes to use for downloading. + **process_video_kwargs: Additional keyword arguments passed to `download_and_process_video`, + allowing custom settings for processing each video. + + Workflow: + 1. Creates the primary output directory and a subdirectory `_videos_raw` for raw downloads. + 2. Calls `construct_download_queue` to prepare a list of video download tasks based on the metadata + available in `source_dir`. Each entry in the queue includes details needed for downloading and processing. + 3. Uses a multiprocessing `Pool` to execute `download_and_process_video` for each video in `download_queue`, + with progress displayed via tqdm. + 4. After completing downloads, provides a message about optional cleanup for temporary video files. + + Returns: + None + + Raises: + AssertionError: If certain data inconsistencies are detected during download queue construction, such as + missing or malformed intervals, crops, or resolution information. 
+ """ os.makedirs(output_dir, exist_ok=True) os.makedirs(os.path.join(output_dir, '_videos_raw'), exist_ok=True) @@ -37,57 +74,120 @@ def download_hdtf(source_dir: os.PathLike, output_dir: os.PathLike, num_workers: video_data=vd, output_dir=output_dir, **process_video_kwargs, - ) for vd in download_queue] + ) for vd in download_queue] pool = Pool(processes=num_workers) - tqdm_kwargs = dict(total=len(task_kwargs), desc=f'Downloading videos into {output_dir} (note: without sound)') + tqdm_kwargs = dict(total=len(task_kwargs), + desc=f'Downloading videos into {output_dir}') for _ in tqdm(pool.imap_unordered(task_proxy, task_kwargs), **tqdm_kwargs): pass + pool.close() + pool.join() - print('Download is finished, you can now (optionally) delete the following directories, since they are not needed anymore and occupy a lot of space:') - print(' -', os.path.join(output_dir, '_videos_raw')) + print(Fore.GREEN+'Download is finished, you can now (optionally) delete the following directories, since they are not needed anymore and occupy a lot of space:') + print(Fore.GREEN+' - '+os.path.join(output_dir, '_videos_raw')) def construct_download_queue(source_dir: os.PathLike, output_dir: os.PathLike) -> List[Dict]: + """ + Constructs a queue of videos to be downloaded and processed based on metadata from the HDTF dataset. + + This function reads metadata files for each subset in the HDTF dataset, which provide information on: + - Video URLs. + - Time intervals indicating segments to be extracted from each video. + - Crop coordinates defining the regions of interest. + - Resolution information for each video. + + For each valid video file, an entry is created in the download queue with detailed information required + for downloading, cropping, and segmenting. + + Args: + source_dir (os.PathLike): Path to the directory containing metadata files (`*_video_url.txt`, + `*_crop_wh.txt`, `*_annotion_time.txt`, and `*_resolution.txt`) for each subset. + output_dir (os.PathLike): Path to the directory where the downloaded and processed videos will be stored. + + Returns: + List[Dict]: A list of dictionaries, each representing a video to download and process. Each dictionary + contains the following keys: + - 'name': Combined subset and video name identifier. + - 'id': YouTube video ID extracted from the video URL. + - 'intervals': List of start and end times for each clip segment. + - 'crops': List of crop coordinates for each segment. + - 'output_dir': The output directory path for this video. + - 'resolution': Desired resolution for the video. + + Workflow: + 1. Reads metadata files for each subset (e.g., "RD", "WDA", "WRA") to gather video URLs, time intervals, crops, + and resolution information. + 2. For each video: + - Ensures it has valid time intervals and resolution data. + - Verifies that all segments have corresponding crop information. + - Discards videos missing required metadata, and prints warnings about invalid or missing data. + 3. Creates a download queue entry for each valid video with the required download and processing data. + + Raises: + AssertionError: If the video segment data is inconsistent, such as: + - Missing or malformed time intervals. + - Incomplete or non-square crop data. + These assertions ensure that only well-formed entries are added to the download queue. 
+ + Example: + >>> construct_download_queue("HDTF_dataset", "/tmp/data/hdtf") + [{'name': 'RD_sample_video', 'id': 'abc123', 'intervals': [[0, 10], [15, 25]], + 'crops': [[0, 128, 0, 128], [0, 128, 0, 128]], 'output_dir': '/tmp/data/hdtf', 'resolution': '720p'}] + """ download_queue = [] for subset in subsets: - video_urls = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_video_url.txt')) - crops = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_crop_wh.txt')) - intervals = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_annotion_time.txt')) - resolutions = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_resolution.txt')) + video_urls = read_file_as_space_separated_data( + os.path.join(source_dir, f'{subset}_video_url.txt')) + crops = read_file_as_space_separated_data( + os.path.join(source_dir, f'{subset}_crop_wh.txt')) + intervals = read_file_as_space_separated_data( + os.path.join(source_dir, f'{subset}_annotion_time.txt')) + resolutions = read_file_as_space_separated_data( + os.path.join(source_dir, f'{subset}_resolution.txt')) for video_name, (video_url,) in video_urls.items(): if not f'{video_name}.mp4' in intervals: - print(f'Entire {subset}/{video_name} does not contain any clip intervals, hence is broken. Discarding it.') + print( + f'{Fore.RED}Clip {subset}/{video_name} does not contain any clip intervals. It will be discarded.') continue if not f'{video_name}.mp4' in resolutions or len(resolutions[f'{video_name}.mp4']) > 1: - print(f'Entire {subset}/{video_name} does not contain the resolution (or it is in a bad format), hence is broken. Discarding it.') + print(f'{Fore.RED}Clip {subset}/{video_name} does not contain an appropriate resolution (or it is in a bad format). It will be discarded.') continue - all_clips_intervals = [x.split('-') for x in intervals[f'{video_name}.mp4']] + all_clips_intervals = [x.split('-') + for x in intervals[f'{video_name}.mp4']] clips_crops = [] clips_intervals = [] for clip_idx, clip_interval in enumerate(all_clips_intervals): clip_name = f'{video_name}_{clip_idx}.mp4' if not clip_name in crops: - print(f'Clip {subset}/{clip_name} is not present in crops, hence is broken. Discarding it.') + print( + f'{Fore.RED}Discarding Clip: {subset}/{clip_name}. Clip is not present in crops.') continue + else: + print(f'{Fore.GREEN}Appending Clip: {subset}/{clip_name}') clips_crops.append(crops[clip_name]) clips_intervals.append(clip_interval) clips_crops = [list(map(int, cs)) for cs in clips_crops] if len(clips_crops) == 0: - print(f'Entire {subset}/{video_name} does not contain any crops, hence is broken. Discarding it.') + print( + f'{Fore.RED}Discarding {subset}/{video_name}. 
No cropped versions found.') continue assert len(clips_intervals) == len(clips_crops) - assert set([len(vi) for vi in clips_intervals]) == {2}, f"Broken time interval, {clips_intervals}" - assert set([len(vc) for vc in clips_crops]) == {4}, f"Broken crops, {clips_crops}" - assert all([vc[1] == vc[3] for vc in clips_crops]), f'Some crops are not square, {clips_crops}' + assert set([len(vi) for vi in clips_intervals]) == { + 2}, f"Broken time interval, {clips_intervals}" + assert set([len(vc) for vc in clips_crops]) == { + 4}, f"Broken crops, {clips_crops}" + assert all([vc[1] == vc[3] for vc in clips_crops] + ), f'Some crops are not square, {clips_crops}' download_queue.append({ 'name': f'{subset}_{video_name}', @@ -100,45 +200,147 @@ def construct_download_queue(source_dir: os.PathLike, output_dir: os.PathLike) - return download_queue - def task_proxy(kwargs): + """ + A proxy function to execute `download_and_process_video` with unpacked keyword arguments. + + This function serves as a wrapper that allows passing a dictionary of arguments (`kwargs`) + to the `download_and_process_video` function. It is primarily used in conjunction with + multiprocessing, where it enables the `Pool.imap_unordered` method to handle the video + processing tasks in parallel. + + Args: + kwargs (dict): A dictionary of arguments required by `download_and_process_video`. + This typically includes: + - 'video_data': A dictionary containing video details (ID, name, intervals, crops, etc.). + - 'output_dir': The directory path where processed clips will be saved. + + Returns: + None + + Usage: + The `task_proxy` function is designed for use with parallel processing. By passing a dictionary + of arguments instead of positional arguments, it enables compatibility with the multiprocessing + pool's mapping methods. + + Example: + >>> task_kwargs = {'video_data': {...}, 'output_dir': '/path/to/output'} + >>> task_proxy(task_kwargs) + + Notes: + This function simplifies the interface for multiprocessing tasks, allowing + `download_and_process_video` to be used directly within the parallel processing workflow + without modifying its original function signature. + """ return download_and_process_video(**kwargs) + def download_and_process_video(video_data: Dict, output_dir: str): """ - Downloads the video and cuts/crops it into several ones according to the provided time intervals + Downloads a video from YouTube and processes it by segmenting and cropping based on provided intervals and crop data. + + The function performs the following steps: + 1. Downloads the specified video to a raw file path within the `_videos_raw` subdirectory of `output_dir`. + 2. Iterates over the specified intervals and crop data to create individual video clips: + - Each clip is extracted according to its specified time interval. + - Each clip is cropped based on the coordinates provided in `video_data['crops']`. + 3. Saves each processed clip in `output_dir` with a unique name indicating the video and clip index. + + Args: + video_data (dict): A dictionary containing metadata for the video to be downloaded and processed. + Expected keys include: + - 'id': The YouTube ID of the video. + - 'name': A unique name identifier for the video. + - 'intervals': A list of time intervals (start, end) for each clip segment. + - 'crops': A list of crop coordinates (x, width, y, height) for each clip segment. + - 'resolution': The desired resolution of the video. + output_dir (str): Path to the directory where processed video clips will be saved. 
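+
+    Note:
+        The crop layout [x, width, y, height] is positional: `cut_and_crop_video`
+        unpacks it as `x, out_w, y, out_h = crop` and passes ffmpeg the filter
+        `crop={out_w}:{out_h}:{x}:{y}` (ffmpeg expects width:height:x:y).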
+
+    Workflow:
+        - Downloads the video using `download_video`, saving it as `{video_name}.mp4` in `_videos_raw`.
+        - For each time interval in `video_data['intervals']`:
+            - Extracts the segment and applies cropping according to the corresponding entry in `video_data['crops']`.
+            - Saves each clip in `output_dir` with a file name formatted as `{video_name}_{clip_idx:03d}.mp4`.
+        - Logs errors to the console if downloading or processing fails for a particular segment or crop.
+
+    Returns:
+        None
+
+    Raises:
+        Nothing explicitly: download or cropping failures are reported to the
+        console and the affected video (or clip) is skipped.
+
+    Example:
+        >>> video_data = {
+                'id': 'abc123',
+                'name': 'sample_video',
+                'intervals': [[0, 10], [15, 25]],
+                'crops': [[0, 128, 0, 128], [10, 118, 10, 118]],
+                'output_dir': '/tmp/data/hdtf',
+                'resolution': '720p'
+            }
+        >>> download_and_process_video(video_data, '/tmp/data/hdtf')
+
+    Notes:
+        - This function requires `ffmpeg` to be installed for segmenting and cropping video clips.
+        - Detailed logging is provided to indicate the status of each clip's download and processing.
+        - If `download_video` fails, an error message is printed to the console, and the function skips further processing.
     """
-    raw_download_path = os.path.join(output_dir, '_videos_raw', f"{video_data['name']}.mp4")
-    raw_download_log_file = os.path.join(output_dir, '_videos_raw', f"{video_data['name']}_download_log.txt")
-    download_result = download_video(video_data['id'], raw_download_path, resolution=video_data['resolution'], log_file=raw_download_log_file)
+    raw_download_path = os.path.join(
+        output_dir, '_videos_raw', f"{video_data['name']}.mp4")
+    raw_download_log_file = os.path.join(
+        output_dir, '_videos_raw', f"{video_data['name']}_download_log.txt")
+    print(f"{Fore.LIGHTBLUE_EX} raw_download_path: {raw_download_path}")
 
-    if not download_result:
-        print('Failed to download', video_data)
-        print(f'See {raw_download_log_file} for details')
-        return
+    # download_video returns a (path, success) tuple; unpack it so the
+    # failure check below actually fires (a non-empty tuple is always truthy).
+    _, download_ok = download_video(
+        video_data['id'], raw_download_path, log_file=raw_download_log_file)
 
-    # We do not know beforehand, what will be the resolution of the downloaded video
-    # Youtube-dl selects a (presumably) highest one
-    video_resolution = get_video_resolution(raw_download_path)
-    if not video_resolution != video_data['resolution']:
-        print(f"Downloaded resolution is not correct for {video_data['name']}: {video_resolution} vs {video_data['name']}. Discarding this video.")
+    if not download_ok:
+        print(f'{Fore.RED} Failed to download {video_data["name"]}')
+        print(f'{Fore.RED} See {raw_download_log_file} for details')
         return
 
     for clip_idx in range(len(video_data['intervals'])):
         start, end = video_data['intervals'][clip_idx]
         clip_name = f'{video_data["name"]}_{clip_idx:03d}'
         clip_path = os.path.join(output_dir, clip_name + '.mp4')
-        crop_success = cut_and_crop_video(raw_download_path, clip_path, start, end, video_data['crops'][clip_idx])
+        crop_success = cut_and_crop_video(
+            raw_download_path, clip_path, start, end, video_data['crops'][clip_idx])
 
         if not crop_success:
-            print(f'Failed to cut-and-crop clip #{clip_idx}', video_data)
+            print(f'{Fore.RED} Failed to cut-and-crop clip #{clip_idx}')
+            pprint.pprint(video_data, indent=4, sort_dicts=False)
             continue
 
-
 def read_file_as_space_separated_data(filepath: os.PathLike) -> Dict:
     """
-    Reads a file as a space-separated dataframe, where the first column is the index
+    Reads a space-separated file and returns its contents as a dictionary.
+
+    This function reads a text file where each line contains space-separated values.
+    The first value in each line is treated as the key, and the remaining values are
+    stored as a list associated with that key. This is useful for parsing metadata
+    files with a consistent space-separated format.
+
+    Args:
+        filepath (os.PathLike): The path to the file to be read.
+
+    Returns:
+        Dict: A dictionary where each key corresponds to the first item in a line,
+            and each value is a list of the remaining items in that line.
+
+    Example:
+        Suppose `example.txt` contains:
+            video1 1280 720
+            video2 640 480
+        >>> read_file_as_space_separated_data("example.txt")
+        {'video1': ['1280', '720'], 'video2': ['640', '480']}
+
+    Notes:
+        - Blank lines are not supported and may cause errors.
+        - Each line must contain at least one space-separated value to be valid.
+
+    Raises:
+        IOError: If the file cannot be opened or read.
     """
     with open(filepath, 'r') as f:
         lines = f.read().splitlines()
@@ -148,91 +350,196 @@ def read_file_as_space_separated_data(filepath: os.PathLike) -> Dict:
     return data
 
 
-def download_video(video_id, download_path, resolution: int=None, video_format="mp4", log_file=None):
+def download_video(video_id, download_path, video_format="bestvideo+bestaudio", log_file=None):
     """
-    Download video from YouTube.
-    :param video_id: YouTube ID of the video.
-    :param download_path: Where to save the video.
-    :param video_format: Format to download.
-    :param log_file: Path to a log file for youtube-dl.
-    :return: Tuple: path to the downloaded video and a bool indicating success.
-
-    Copy-pasted from https://github.com/ytdl-org/youtube-dl
+    Downloads a YouTube video in the specified format and saves it to a given path.
+
+    This function uses `yt-dlp` to download a video by its YouTube ID, selecting the highest
+    available quality by default. It accepts a custom format selector and can log download
+    progress and errors to a specified log file.
+
+    Args:
+        video_id (str): The YouTube ID of the video to download.
+        download_path (str): The full path (including file name) where the downloaded video will be saved.
+        video_format (str, optional): The video and audio format selection for yt-dlp. Defaults to
+            "bestvideo+bestaudio" for the highest available quality.
+        log_file (str, optional): Path to a file where log messages (debug, warnings, and errors)
+            will be recorded. If None, logging to a file is disabled.
+
+    Returns:
+        Tuple[str, bool]: A tuple where:
+            - The first element is the path to the downloaded video file.
+            - The second element is a boolean indicating success (True if the file
+              was downloaded successfully, False otherwise).
+
+    Workflow:
+        1. Constructs `yt-dlp` options based on the provided arguments, including `format`, `outtmpl`,
+           and `logger` if a log file is specified.
+        2. Attempts to download the video. If successful, verifies the file exists at `download_path`.
+        3. Logs errors if the download fails and saves them to `log_file` if specified.
+
+    Raises:
+        Nothing: exceptions raised during the download are caught (and logged to
+        `log_file` if provided), and the returned success flag is set to False.
+
+    Example:
+        >>> download_video("abc123", "/path/to/video.mp4", log_file="/path/to/log.txt")
+        ("/path/to/video.mp4", True)
+
+    Notes:
+        - Requires `yt-dlp` to be installed.
+        - Requires `ffmpeg` if merging video and audio streams is necessary.
+        - Custom logging is provided through a nested `Logger` class if `log_file` is specified.
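+
+    Usage sketch (the return value is a (path, success) tuple; unpack it rather
+    than truth-testing the tuple itself, which is always truthy):
+        >>> _, ok = download_video("abc123", "/path/to/video.mp4")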
""" - # if os.path.isfile(download_path): return True # File already exists - - if log_file is None: - stderr = subprocess.DEVNULL - else: - stderr = open(log_file, "a") - video_selection = f"bestvideo[ext={video_format}]" - video_selection = video_selection if resolution is None else f"{video_selection}[height={resolution}]" - command = [ - "youtube-dl", - "https://youtube.com/watch?v={}".format(video_id), "--quiet", "-f", - video_selection, - "--output", download_path, - "--no-continue" - ] - return_code = subprocess.call(command, stderr=stderr) - success = return_code == 0 - - if log_file is not None: - stderr.close() - - return success and os.path.isfile(download_path) - - -def get_video_resolution(video_path: os.PathLike) -> int: - command = ' '.join([ - "ffprobe", - "-v", "error", - "-select_streams", "v:0", "-show_entries", "stream=height", "-of", "csv=p=0", - video_path - ]) - - process = Popen(command, stdout=PIPE, shell=True) - (output, err) = process.communicate() - return_code = process.wait() - success = return_code == 0 - - if not success: - print('Command failed:', command) - return -1 - - return int(output) - + class Logger: + """ + A simple logger for yt-dlp to write debug, warning, and error messages to a specified log file. + + Attributes: + log_path (str): Path to the log file where messages will be written. + """ + + def __init__(self, log_path): + """ + Initializes the Logger with a log file path. + + :param log_path: Path to the file where log messages should be saved. + """ + self.log_path = log_path + + def debug(self, msg): + """ + Logs a debug message. + + :param msg: The debug message to log. + """ + with open(self.log_path, "a") as f: + f.write(f"DEBUG: {msg}\n") + + def warning(self, msg): + """ + Logs a warning message. + + :param msg: The warning message to log. + """ + with open(self.log_path, "a") as f: + f.write(f"WARNING: {msg}\n") + + def error(self, msg): + """ + Logs an error message. + + :param msg: The error message to log. + """ + with open(self.log_path, "a") as f: + f.write(f"ERROR: {msg}\n") + + # Define yt-dlp options + ydl_opts = { + # Set video format to best video and audio by default + 'format': video_format, + 'outtmpl': download_path, # Output path template + 'quiet': True, # Suppress verbose output + 'merge_output_format': 'mp4', # Ensure output format is MP4 + } + + # If a log file is specified, configure the logger + if log_file: + ydl_opts['logger'] = Logger(log_file) + + # Download the video using yt-dlp + try: + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download([f'https://www.youtube.com/watch?v={video_id}']) + success = True + except Exception as e: + success = False + if log_file: + with open(log_file, "a") as f: + f.write( + f"ERROR: Failed to download {video_id}. Exception: {str(e)}\n") + + result = success and os.path.isfile(download_path) + return download_path, result def cut_and_crop_video(raw_video_path, output_path, start, end, crop: List[int]): - # if os.path.isfile(output_path): return True # File already exists - + """ + Cuts and crops a video segment from a specified start to end time and saves it to the output path. + + This function uses `ffmpeg` to: + 1. Extract a segment of the video from `start` to `end` time. + 2. Apply a crop filter to the segment based on the provided crop coordinates. + 3. Save the processed clip to `output_path` with the original quality preserved. + + Args: + raw_video_path (str): Path to the source video file to be processed. 
+ output_path (str): Path where the processed video clip will be saved, including the file name. + start (float or int): Start time in seconds for the video segment to be cut. + end (float or int): End time in seconds for the video segment to be cut. + crop (List[int]): A list specifying crop parameters [x, width, y, height], where: + - x (int): The x-coordinate of the top-left corner of the crop area. + - width (int): The width of the crop area. + - y (int): The y-coordinate of the top-left corner of the crop area. + - height (int): The height of the crop area. + + Returns: + bool: True if the cutting and cropping were successful, False otherwise. + + Workflow: + 1. Constructs an `ffmpeg` command to cut the video from `start` to `end` and apply the specified crop filter. + 2. Executes the command with `subprocess.call` to process the video. + 3. Checks the return code to confirm successful execution. Prints a message if the process fails. + + Raises: + ValueError: If `crop` does not contain exactly four values, or if any component is invalid. + FileNotFoundError: If `ffmpeg` is not installed or accessible from the system PATH. + + Example: + >>> cut_and_crop_video( + raw_video_path="/path/to/source.mp4", + output_path="/path/to/clip.mp4", + start=10, + end=20, + crop=[50, 200, 30, 200] + ) + True + + Notes: + - Requires `ffmpeg` to be installed and accessible from the command line. + - If `output_path` already exists, it will be overwritten. + - `-qscale 0` is used to preserve the video quality. + - The crop filter uses the format `crop=width:height:x:y`, where `x` and `y` specify the top-left corner. + """ x, out_w, y, out_h = crop - command = ' '.join([ + command = [ "ffmpeg", "-i", raw_video_path, - "-strict", "-2", # Some legacy arguments - "-loglevel", "quiet", # Verbosity arguments - "-qscale", "0", # Preserve the quality - "-y", # Overwrite if the file exists - "-ss", str(start), "-to", str(end), # Cut arguments - "-filter:v", f'"crop={out_w}:{out_h}:{x}:{y}"', # Crop arguments + "-strict", "-2", # Some legacy arguments + "-loglevel", "quiet", # Verbosity arguments + "-qscale", "0", # Preserve the quality + "-y", # Overwrite if the file exists + "-ss", str(start), + "-to", str(end), + "-filter:v", f"crop={out_w}:{out_h}:{x}:{y}", # Crop arguments output_path - ]) - - return_code = subprocess.call(command, shell=True) + ] + return_code = subprocess.call(command) success = return_code == 0 if not success: - print('Command failed:', command) + print(f'{Fore.RED} Command failed: {" ".join(command)}') return success if __name__ == "__main__": parser = argparse.ArgumentParser(description="Download HDTF dataset") - parser.add_argument('-s', '--source_dir', type=str, default='HDTF_dataset', help='Path to the directory with the dataset') - parser.add_argument('-o', '--output_dir', type=str, help='Where to save the videos?') - parser.add_argument('-w', '--num_workers', type=int, default=8, help='Number of workers for downloading') + parser.add_argument('-s', '--source_dir', type=str, default='HDTF_dataset', + help='Path to the directory with the dataset description') + parser.add_argument('-o', '--output_dir', type=str, default='download', + help='Where to save the videos?') + parser.add_argument('-w', '--num_workers', type=int, default=1, + help='Number of workers for downloading.') args = parser.parse_args() download_hdtf(