diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..46c31ca
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+download/
+output_logfile.txt
diff --git a/README.md b/README.md
index 036b12a..7933970 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,20 @@
-
 # HDTF
 Flow-guided One-shot Talking Face Generation with a High-resolution Audio-visual Dataset
-paper supplementary
+paper supplementary [demo video](https://www.youtube.com/watch?v=uJdBgWYBTww)
+
 ## Details of HDTF dataset
-**./HDTF_dataset** consists of *youtube video url*, *video resolution* (in our method, may not be the best resolution), *time stamps of talking face*, *facial region* (in the our method) and *the zoom scale* of the cropped window.
+**./HDTF_dataset** consists of the *YouTube video URL*, the *video resolution* (used in our method; it may not be the best resolution), the *time stamps of the talking face*, the *facial region* (used in our method) and the *zoom scale* of the cropped window.
+
 **xx_video_url.txt:**
-
```
format: video name | video youtube url
@@ -44,7 +52,20 @@ When using HDTF dataset,
- We resize all cropped videos into **512 x 512** resolution.
-The HDTF dataset is available to download under a Creative Commons Attribution 4.0 International License. If you face any problems when processing HDTF, pls contact me.
+The HDTF dataset is available to download under a Creative Commons Attribution 4.0 International License. **Thanks to @universome for providing the data-processing script; please visit [here](https://github.com/universome/HDTF) for more details.** If you face any problems when processing HDTF, please contact me.
+
+## Inference code
+#### code of audio-to-animation
+Coming soon...
+
+#### code of constructing approximate dense flow
+The code is in **./code_constructing_Fapp**; please visit [here](https://github.com/MRzzm/HDTF/tree/main/code_constructing_Fapp) for more details.
+
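+For example, after preparing the test data described in that readme, you can run the following from inside **./code_constructing_Fapp**:
+```
+python inference.py --reference_projected_mesh_points_path=./test_data/taile_source_points.npy --drive_projected_mesh_points_path=./test_data/taile_drive_points.npy
+```
+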
+#### code of animation-to-video module
+The code is in **./code_animation2video**; please visit [here](https://github.com/MRzzm/HDTF/tree/main/code_animation2video) for more details.
+
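+For example, after downloading the checkpoint and test data described in that readme, you can run the following from inside **./code_animation2video**:
+```
+python inference.py --image_path=./test_data/taile.jpg --dense_flow_path=./test_data/taile_Fapp.npy
+```
+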
+#### code of reproducing other works
+Coming soon...
## Downloading
For convenience, we added the `download.py` script which downloads, crops and resizes the dataset. You can use it via the following command:
diff --git a/code_animation2video/inference.py b/code_animation2video/inference.py
new file mode 100644
index 0000000..0744d9f
--- /dev/null
+++ b/code_animation2video/inference.py
@@ -0,0 +1,115 @@
+import os
+import subprocess
+
+import cv2
+import numpy as np
+import torch
+
+from models import VideoGenerator
+from opts import parse_opts
+from utils import visualize_dense_flow, make_coordinate_grid
+if __name__ == "__main__":
+ opt = parse_opts()
+ ### load trained model
+ video_generator = VideoGenerator(opt.input_channel, opt.encoder_num_down_blocks, opt.encoder_block_expansion,
+ opt.encoder_max_features, opt.houglass_num_blocks,
+ opt.houglass_block_expansion, opt.houglass_max_features,
+ opt.num_bottleneck_blocks).cuda()
+    old_dict = torch.load(opt.model_path)['state_dict']['net_g']
+    # strip the 'module.' prefix added by DataParallel and map the checkpoint's
+    # layer names onto the module names used in this release
+    new_dict = {}
+    for k, v in old_dict.items():
+        name = k[7:].replace('dense_motion','foreground_matting').replace('attention_mask','matting_mask')
+        new_dict[name] = v
+ video_generator.load_state_dict(new_dict)
+ video_generator.eval()
+ #### load reference image
+ reference_image = cv2.imread(opt.image_path)
+ reference_tensor = torch.from_numpy(reference_image / 255).permute(2, 0, 1).float().unsqueeze(0).cuda()
+ #### load approximate dense flow
+    Fapp = np.load(opt.dense_flow_path)  # shape: (num_frames, H, W, 2)
+    #### output setting
+    os.makedirs(opt.res_path, exist_ok=True)  # make sure the result dir exists
+    ## generated video
+    synthetic_video_path = os.path.join(opt.res_path, os.path.basename(opt.image_path)[:-4] + '_video.mp4')
+ if os.path.exists(synthetic_video_path):
+ os.remove(synthetic_video_path)
+ videowriter_synthetic_video = cv2.VideoWriter(synthetic_video_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (512, 512))
+ ## app dense flow
+ Fapp_video_path = synthetic_video_path.replace('video.mp4','Fapp.mp4')
+ if os.path.exists(Fapp_video_path):
+ os.remove(Fapp_video_path)
+ videowriter_Fapp = cv2.VideoWriter(Fapp_video_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (512, 512))
+ ## matting mask
+ matting_mask_path = synthetic_video_path.replace('video.mp4','matting_mask.mp4')
+ if os.path.exists(matting_mask_path):
+ os.remove(matting_mask_path)
+ videowriter_matting_mask = cv2.VideoWriter(matting_mask_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+ ## revised dense flow
+ revised_dense_path = synthetic_video_path.replace('video.mp4', 'revised_dense.mp4')
+ if os.path.exists(revised_dense_path):
+ os.remove(revised_dense_path)
+ videowriter_revised_dense = cv2.VideoWriter(revised_dense_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+ ## foreground mask
+ foreground_mask_path = synthetic_video_path.replace('video.mp4', 'foreground_mask.mp4')
+ if os.path.exists(foreground_mask_path):
+ os.remove(foreground_mask_path)
+ videowriter_foreground_mask = cv2.VideoWriter(foreground_mask_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+ ## warped image
+ warped_image_path = synthetic_video_path.replace('video.mp4', 'warped_image.mp4')
+ if os.path.exists(warped_image_path):
+ os.remove(warped_image_path)
+ videowriter_warped_image = cv2.VideoWriter(warped_image_path, cv2.VideoWriter_fourcc(*'XVID'), 30, (128, 128))
+ ######### generate video frame by frame
+ frame_length = Fapp.shape[0]
+ for i in range(frame_length):
+ print('generating frame {}/{} '.format(i, frame_length))
+ Fapp_i = Fapp[i,:,:]
+ Fapp_i_visual = visualize_dense_flow(Fapp_i - make_coordinate_grid((Fapp_i.shape[0],Fapp_i.shape[1])))
+ videowriter_Fapp.write(Fapp_i_visual)
+ with torch.no_grad():
+ Fapp_i = torch.from_numpy(Fapp_i).float().cuda().unsqueeze(0)
+ res_out = video_generator(reference_tensor, Fapp_i)
+ ## synthetic_video
+ synthetic_video_i = res_out['synthetic_image'] * 255
+ synthetic_video_i = synthetic_video_i.cpu().squeeze().permute(1, 2, 0).float().detach().numpy().astype(np.uint8)
+ videowriter_synthetic_video.write(synthetic_video_i)
+ ## warped image
+ warped_image_i = res_out['warped_image']* 255
+ warped_image_i = warped_image_i.cpu().squeeze().permute(1, 2, 0).float().detach().numpy().astype(np.uint8)
+ videowriter_warped_image.write(warped_image_i)
+ ## foreground_mask
+ foreground_mask_i = res_out['foreground_mask']* 255
+ foreground_mask_i =np.expand_dims(foreground_mask_i.cpu().squeeze().detach().numpy().astype(np.uint8),2)
+ foreground_mask_i = foreground_mask_i.repeat(3,2)
+ videowriter_foreground_mask.write(foreground_mask_i)
+ ## dense_flow_foreground_vis
+ dense_flow_foreground_vis_i = foreground_mask_i/255.0 * cv2.resize(Fapp_i_visual,(128,128))
+ dense_flow_foreground_vis_i = dense_flow_foreground_vis_i.astype(np.uint8)
+ videowriter_revised_dense.write(dense_flow_foreground_vis_i)
+ ### matting_mask
+ matting_mask_i = res_out['matting_mask'] * 255
+ matting_mask_i = matting_mask_i.cpu().squeeze().detach().numpy().astype(np.uint8)
+ matting_mask_i = np.expand_dims(matting_mask_i,2).repeat(3,2)
+ videowriter_matting_mask.write(matting_mask_i)
+
+ videowriter_Fapp.release()
+ videowriter_synthetic_video.release()
+ videowriter_warped_image.release()
+ videowriter_foreground_mask.release()
+ videowriter_revised_dense.release()
+ videowriter_matting_mask.release()
+
+    if os.path.exists(opt.audio_path):
+        video_add_audio_path = synthetic_video_path.replace('.mp4', '_add_audio.mp4')
+        if os.path.exists(video_add_audio_path):
+            os.remove(video_add_audio_path)
+        # mux the synthetic video with the driving audio: the video stream is
+        # copied as-is and the audio stream is encoded to AAC
+        cmd = 'ffmpeg -i {} -i {} -c:v copy -c:a aac -strict experimental -map 0:v:0 -map 1:a:0 {}'.format(
+            synthetic_video_path,
+            opt.audio_path,
+            video_add_audio_path)
+        subprocess.call(cmd, shell=True)
+
+
diff --git a/code_animation2video/models.py b/code_animation2video/models.py
new file mode 100644
index 0000000..014680c
--- /dev/null
+++ b/code_animation2video/models.py
@@ -0,0 +1,292 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+
+class ResBlock2d(nn.Module):
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.norm1 = BatchNorm2d(in_features)
+ self.norm2 = BatchNorm2d(in_features)
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = self.norm1(x)
+ out = self.relu(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out += x
+ return out
+
+class UpBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(UpBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features)
+ self.relu = nn.ReLU()
+    def forward(self, x):
+        out = F.interpolate(x, scale_factor=2)
+        out = self.conv(out)
+        out = self.norm(out)
+        out = self.relu(out)
+        return out
+
+class DownBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+ super(DownBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+ padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features)
+ self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.relu(out)
+ out = self.pool(out)
+ return out
+
+class SameBlock2d(nn.Module):
+ def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1):
+ super(SameBlock2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+ kernel_size=kernel_size, padding=padding, groups=groups)
+ self.norm = BatchNorm2d(out_features)
+ self.relu = nn.ReLU()
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+class HourglassEncoder(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(HourglassEncoder, self).__init__()
+ down_blocks = []
+ for i in range(num_blocks):
+ down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+ min(max_features, block_expansion * (2 ** (i + 1))),
+ kernel_size=3, padding=1))
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ def forward(self, x):
+ outs = [x]
+ for down_block in self.down_blocks:
+ outs.append(down_block(outs[-1]))
+ return outs
+
+class HourglassDecoder(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(HourglassDecoder, self).__init__()
+ up_blocks = []
+ for i in range(num_blocks)[::-1]:
+ in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+ out_filters = min(max_features, block_expansion * (2 ** i))
+ up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
+ self.up_blocks = nn.ModuleList(up_blocks)
+ self.out_filters = block_expansion + in_features
+ def forward(self, x):
+ out = x.pop()
+ for up_block in self.up_blocks:
+ out = up_block(out)
+ skip = x.pop()
+ out = torch.cat([out, skip], dim=1)
+ return out
+
+class Hourglass(nn.Module):
+ def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+ super(Hourglass, self).__init__()
+ self.encoder = HourglassEncoder(block_expansion, in_features, num_blocks, max_features)
+ self.decoder = HourglassDecoder(block_expansion, in_features, num_blocks, max_features)
+ self.out_filters = self.decoder.out_filters
+ def forward(self, x):
+ return self.decoder(self.encoder(x))
+
+class AntiAliasInterpolation2d(nn.Module):
+ """
+ Band-limited downsampling, for better preservation of the input signal.
+ """
+ def __init__(self, channels, scale):
+ super(AntiAliasInterpolation2d, self).__init__()
+ sigma = (1 / scale - 1) / 2
+ kernel_size = 2 * round(sigma * 4) + 1
+ self.ka = kernel_size // 2
+ self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+ kernel_size = [kernel_size, kernel_size]
+ sigma = [sigma, sigma]
+ # The gaussian kernel is the product of the
+ # gaussian function of each dimension.
+ kernel = 1
+ meshgrids = torch.meshgrid(
+ [
+ torch.arange(size, dtype=torch.float32)
+ for size in kernel_size
+ ]
+ )
+ for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+ mean = (size - 1) / 2
+ kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+ # Make sure sum of values in gaussian kernel equals 1.
+ kernel = kernel / torch.sum(kernel)
+ # Reshape to depthwise convolutional weight
+ kernel = kernel.view(1, 1, *kernel.size())
+ kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+ self.register_buffer('weight', kernel)
+ self.groups = channels
+ self.scale = scale
+
+ def forward(self, input):
+ if self.scale == 1.0:
+ return input
+
+ out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+ out = F.conv2d(out, weight=self.weight, groups=self.groups)
+ out = F.interpolate(out, scale_factor=(self.scale, self.scale))
+
+ return out
+
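+# Usage note (illustration, not part of the original code): with scale=0.25 an
+# input of shape (B, C, H, W) is Gaussian-blurred and then resized to
+# (B, C, H/4, W/4); scale=1.0 returns the input unchanged.
+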
+class Encoder(nn.Module):
+ def __init__(self, num_channels, num_down_blocks=3, block_expansion=64, max_features=512,
+ ):
+ super(Encoder, self).__init__()
+ self.in_conv = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+ down_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** i))
+ out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+ down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.down_blocks = nn.Sequential(*down_blocks)
+ def forward(self, image):
+ out = self.in_conv(image)
+ out = self.down_blocks(out)
+ return out
+
+class Bottleneck(nn.Module):
+ def __init__(self, num_bottleneck_blocks,num_down_blocks=3, block_expansion=64, max_features=512):
+ super(Bottleneck, self).__init__()
+ bottleneck = []
+ in_features = min(max_features, block_expansion * (2 ** num_down_blocks))
+ for i in range(num_bottleneck_blocks):
+ bottleneck.append(ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.bottleneck = nn.Sequential(*bottleneck)
+ def forward(self, feature_map):
+ out = self.bottleneck(feature_map)
+ return out
+
+class Decoder(nn.Module):
+ def __init__(self,num_channels, num_down_blocks=3, block_expansion=64, max_features=512):
+ super(Decoder, self).__init__()
+ up_blocks = []
+ for i in range(num_down_blocks):
+ in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))
+ out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))
+ up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_conv = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))
+ self.sigmoid = nn.Sigmoid()
+ def forward(self, feature_map):
+ out = self.up_blocks(feature_map)
+ out = self.out_conv(out)
+ out = self.sigmoid(out)
+ return out
+
+def warp_image(image, motion_flow):
+    # resample `image` at the locations given by `motion_flow`, a grid_sample
+    # sampling grid in [-1, 1] coordinates; the flow is resized to the image
+    # resolution first if the two differ
+    _, h_old, w_old, _ = motion_flow.shape
+    _, _, h, w = image.shape
+    if h_old != h or w_old != w:
+        motion_flow = motion_flow.permute(0, 3, 1, 2)
+        motion_flow = F.interpolate(motion_flow, size=(h, w), mode='bilinear')
+        motion_flow = motion_flow.permute(0, 2, 3, 1)
+    return F.grid_sample(image, motion_flow)
+
+def make_coordinate_grid(spatial_size, type):
+ h, w = spatial_size
+ x = torch.arange(w).type(type)
+ y = torch.arange(h).type(type)
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ yy = y.view(-1, 1).repeat(1, w)
+ xx = x.view(1, -1).repeat(h, 1)
+ meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+ return meshed
+
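+# Usage sketch (illustration, not part of the original code): make_coordinate_grid
+# returns the identity sampling grid in F.grid_sample's [-1, 1] convention, so
+# warping an image with it should roughly reproduce the image:
+#   grid = make_coordinate_grid((h, w), image.type()).unsqueeze(0)
+#   reconstructed = warp_image(image, grid)  # ~= image up to resampling
+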
+class ForegroundMatting(nn.Module):
+ def __init__(self, num_channels,scale_factor,matting_channel,num_blocks,block_expansion, max_features):
+ super(ForegroundMatting, self).__init__()
+ self.down_sample_image = AntiAliasInterpolation2d(num_channels, scale_factor)
+ self.down_sample_flow = AntiAliasInterpolation2d(2, scale_factor)
+ self.hourglass = Hourglass(block_expansion=block_expansion, in_features= num_channels * 2 + 2,
+ max_features=max_features, num_blocks=num_blocks)
+ self.foreground_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+ self.matting_mask = nn.Conv2d(self.hourglass.out_filters, 1, kernel_size=(7, 7), padding=(3, 3))
+ self.matting = nn.Conv2d(self.hourglass.out_filters, matting_channel, kernel_size=(7, 7), padding=(3, 3))
+ self.scale_factor = scale_factor
+ self.sigmoid = nn.Sigmoid()
+ def forward(self, reference_image, dense_flow):
+        '''
+        reference_image: b x c x h x w
+        dense_flow: b x h x w x 2
+        '''
+ res_out = {}
+ if self.scale_factor != 1: #down sample the image
+ reference_image = self.down_sample_image(reference_image)
+ dense_flow = self.down_sample_flow(dense_flow.permute(0,3,1,2)).permute(0,2,3,1)
+ batch, _, h, w = reference_image.shape
+ warped_image = warp_image(reference_image, dense_flow)#warp the image with dense flow
+ res_out['warped_image'] = warped_image
+ hourglass_input = torch.cat([reference_image,dense_flow.permute(0,3,1,2),warped_image], dim=1)
+ hourglass_out = self.hourglass(hourglass_input)
+ foreground_mask = self.foreground_mask(hourglass_out) # compute foreground mask
+ foreground_mask = self.sigmoid(foreground_mask).permute(0,2,3,1)
+ res_out['foreground_mask'] = foreground_mask
+ grid_flow = make_coordinate_grid((h, w), dense_flow.type())
+ dense_flow_foreground = dense_flow * foreground_mask + (1-foreground_mask) * grid_flow.unsqueeze(0) ## revise the dense flow
+ res_out['dense_flow_foreground'] = dense_flow_foreground
+ res_out['dense_flow_foreground_vis'] = dense_flow * foreground_mask
+ matting_mask = self.matting_mask(hourglass_out) # compute matting mask
+ matting_mask = self.sigmoid(matting_mask)
+ res_out['matting_mask'] = matting_mask
+ matting_image = self.matting(hourglass_out) # computing matting image
+ res_out['matting_image'] = matting_image
+ return res_out
+
+
+
+class VideoGenerator(nn.Module):
+ def __init__(self, num_channels, encoder_num_down_blocks=3,encoder_block_expansion=64,
+ encoder_max_features=512, houglass_num_blocks=5,
+ houglass_block_expansion = 64,houglass_max_features = 1024, num_bottleneck_blocks=6):
+ super(VideoGenerator, self).__init__()
+ self.encoder = Encoder(num_channels,encoder_num_down_blocks,
+ encoder_block_expansion,encoder_max_features)
+ matting_channel = int(min(encoder_max_features, encoder_block_expansion * (2 ** encoder_num_down_blocks)))
+ self.foreground_matting = ForegroundMatting(num_channels,scale_factor=1/(2**encoder_num_down_blocks),matting_channel = matting_channel,
+ num_blocks = houglass_num_blocks,block_expansion =houglass_block_expansion,
+ max_features = houglass_max_features)
+ self.bottleneck = Bottleneck(num_bottleneck_blocks,encoder_num_down_blocks,
+ encoder_block_expansion,encoder_max_features)
+ self.decoder = Decoder(num_channels,encoder_num_down_blocks, encoder_block_expansion,
+ encoder_max_features)
+ def forward(self, reference_image,dense_flow):
+ '''
+        reference_image: b x c x h x w
+ dense_flow: b x h x w x 2
+ '''
+ feature_map = self.encoder(reference_image) ## compute feature map
+ res_out = self.foreground_matting(reference_image, dense_flow) ## compute matting & revise dense flow
+ assert feature_map.shape[2] == res_out['matting_mask'].shape[2] and feature_map.shape[3] == res_out['matting_mask'].shape[3]
+ warped_feature_map = warp_image(feature_map, res_out['dense_flow_foreground']) * res_out['matting_mask'] + (1-res_out['matting_mask']) * res_out['matting_image']
+ warped_feature_map = self.bottleneck(warped_feature_map) # decode feature map
+ synthetic_image = self.decoder(warped_feature_map) # decode feature map
+ res_out['synthetic_image'] = synthetic_image
+ return res_out
\ No newline at end of file
diff --git a/code_animation2video/opts.py b/code_animation2video/opts.py
new file mode 100644
index 0000000..68a5fa0
--- /dev/null
+++ b/code_animation2video/opts.py
@@ -0,0 +1,24 @@
+import argparse
+def parse_opts():
+ parser = argparse.ArgumentParser(description='animation2video')
+ # ========================= Input Configs ==========================
+ parser.add_argument('--model_path', type=str, default=r'./checkpoints/checkpoint_animation2video.pth', help='trained model path')
+ parser.add_argument('--image_path', type=str, default=r'./test_data/taile.jpg', help='reference image path')
+ parser.add_argument('--dense_flow_path', type=str, default=r'./test_data/taile_Fapp.npy', help='reference approximate dense flow path')
+ parser.add_argument('--audio_path', type=str, default=r'./test_data/chuanpu.wav', help='input audio path')
+ parser.add_argument('--res_path', type=str, default=r'./result', help='result path')
+ # ========================= Base Configs ==========================
+ parser.add_argument('--input_channel', type=int, default=3, help='input image channels')
+ parser.add_argument('--out_channel', type=int, default=3, help='output image channels')
+ parser.add_argument('--image_size', type=int, default=512, help='image size')
+ #========================= Network Configs ==========================
+ parser.add_argument('--encoder_num_down_blocks', type=int, default=2, help='network setting')
+ parser.add_argument('--encoder_block_expansion', type=int, default=64, help='network setting')
+ parser.add_argument('--encoder_max_features', type=int, default=512, help='network setting')
+ parser.add_argument('--num_bottleneck_blocks', type=int, default=2, help='network setting')
+ parser.add_argument('--houglass_num_blocks', type=int, default=5, help='network setting')
+ parser.add_argument('--houglass_block_expansion', type=int, default=64, help='network setting')
+ parser.add_argument('--houglass_max_features', type=int, default=512, help='network setting')
+ args = parser.parse_args()
+
+ return args
\ No newline at end of file
diff --git a/code_animation2video/readme.md b/code_animation2video/readme.md
new file mode 100644
index 0000000..b5d748a
--- /dev/null
+++ b/code_animation2video/readme.md
@@ -0,0 +1,13 @@
+# Code of animation-to-video module
+### inference
+
+ 1. Download the trained model (`checkpoint_animation2video.pth`) and the approximate dense flow files (`mengnalisa_Fapp.npy`, `taile_Fapp.npy`) from [google drive](https://drive.google.com/drive/folders/1OM3AE6rjZKY1v6PVDnv-YwlmkBZOhw1L?usp=sharing).
+ 2. Put `checkpoint_animation2video.pth` into **./checkpoints**.
+ 3. Put `mengnalisa_Fapp.npy` and `taile_Fapp.npy` into **./test_data**.
+ 4. Run
+> python inference.py --image_path=./test_data/mengnalisa.jpg --dense_flow_path=./test_data/mengnalisa_Fapp.npy
+
+or
+> python inference.py --image_path=./test_data/taile.jpg --dense_flow_path=./test_data/taile_Fapp.npy
+
+to generate all intermediate results of the animation-to-video module.
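+
+The intermediate results are written into the `--res_path` directory (**./result** by default): the synthetic video (`*_video.mp4`), the visualized approximate dense flow (`*_Fapp.mp4`), the matting mask (`*_matting_mask.mp4`), the revised dense flow (`*_revised_dense.mp4`), the foreground mask (`*_foreground_mask.mp4`) and the warped image (`*_warped_image.mp4`). If `--audio_path` points to an existing audio file, a muxed `*_video_add_audio.mp4` is produced as well.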
\ No newline at end of file
diff --git a/code_animation2video/requirements.txt b/code_animation2video/requirements.txt
new file mode 100644
index 0000000..3065e88
--- /dev/null
+++ b/code_animation2video/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.0.0
+torchvision==0.2.2
+opencv_python==4.4.0.46
+numpy==1.21.2
diff --git a/code_animation2video/sync_batchnorm/__init__.py b/code_animation2video/sync_batchnorm/__init__.py
new file mode 100644
index 0000000..bc8709d
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..3104579
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/__init__.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc
new file mode 100644
index 0000000..1590c73
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/batchnorm.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc
new file mode 100644
index 0000000..c7b8551
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/comm.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc b/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc
new file mode 100644
index 0000000..4db52bc
Binary files /dev/null and b/code_animation2video/sync_batchnorm/__pycache__/replicate.cpython-36.pyc differ
diff --git a/code_animation2video/sync_batchnorm/batchnorm.py b/code_animation2video/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000..5f4e763
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+    """sum over the first and last dimension"""
+ return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+    """add new dimensions at the front and the tail"""
+ return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+ super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+ self._sync_master = SyncMaster(self._data_parallel_master)
+
+ self._is_parallel = False
+ self._parallel_id = None
+ self._slave_pipe = None
+
+ def forward(self, input):
+ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+ if not (self._is_parallel and self.training):
+ return F.batch_norm(
+ input, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training, self.momentum, self.eps)
+
+ # Resize the input to (B, C, -1).
+ input_shape = input.size()
+ input = input.view(input.size(0), self.num_features, -1)
+
+ # Compute the sum and square-sum.
+ sum_size = input.size(0) * input.size(2)
+ input_sum = _sum_ft(input)
+ input_ssum = _sum_ft(input ** 2)
+
+ # Reduce-and-broadcast the statistics.
+ if self._parallel_id == 0:
+ mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+ else:
+ mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+ # Compute the output.
+ if self.affine:
+ # MJY:: Fuse the multiplication for speed.
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+ else:
+ output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+ # Reshape it.
+ return output.view(input_shape)
+
+ def __data_parallel_replicate__(self, ctx, copy_id):
+ self._is_parallel = True
+ self._parallel_id = copy_id
+
+ # parallel_id == 0 means master device.
+ if self._parallel_id == 0:
+ ctx.sync_master = self._sync_master
+ else:
+ self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+ def _data_parallel_master(self, intermediates):
+ """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+ # Always using same "device order" makes the ReduceAdd operation faster.
+ # Thanks to:: Tete Xiao (http://tetexiao.com/)
+ intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+ to_reduce = [i[1][:2] for i in intermediates]
+ to_reduce = [j for i in to_reduce for j in i] # flatten
+ target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+ sum_size = sum([i[1].sum_size for i in intermediates])
+ sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+ mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+ broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+ outputs = []
+ for i, rec in enumerate(intermediates):
+ outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+ return outputs
+
+ def _compute_mean_std(self, sum_, ssum, size):
+ """Compute the mean and standard-deviation with sum and square-sum. This method
+ also maintains the moving average on the master device."""
+ assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+ mean = sum_ / size
+ sumvar = ssum - sum_ * mean
+ unbias_var = sumvar / (size - 1)
+ bias_var = sumvar / size
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+ self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+ return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+ r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+ mini-batch.
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm1d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
+    same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of size
+ `batch_size x num_features [x width]`
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C)` or :math:`(N, C, L)`
+ - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm1d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 2 and input.dim() != 3:
+ raise ValueError('expected 2D or 3D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+ of 3d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm2d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
+    same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, H, W)`
+ - Output: :math:`(N, C, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm2d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 4:
+ raise ValueError('expected 4D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+ r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+ of 4d inputs
+
+ .. math::
+
+ y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+ This module differs from the built-in PyTorch BatchNorm3d as the mean and
+ standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the
+    same as the built-in PyTorch implementation.
+
+ The mean and standard-deviation are calculated per-dimension over
+ the mini-batches and gamma and beta are learnable parameter vectors
+ of size C (where C is the input size).
+
+ During training, this layer keeps a running estimate of its computed mean
+ and variance. The running sum is kept with a default momentum of 0.1.
+
+ During evaluation, this running mean/variance is used for normalization.
+
+ Because the BatchNorm is done over the `C` dimension, computing statistics
+ on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+ or Spatio-temporal BatchNorm
+
+ Args:
+ num_features: num_features from an expected input of
+ size batch_size x num_features x depth x height x width
+ eps: a value added to the denominator for numerical stability.
+ Default: 1e-5
+ momentum: the value used for the running_mean and running_var
+ computation. Default: 0.1
+ affine: a boolean value that when set to ``True``, gives the layer learnable
+ affine parameters. Default: ``True``
+
+ Shape:
+ - Input: :math:`(N, C, D, H, W)`
+ - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+ Examples:
+ >>> # With Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100)
+ >>> # Without Learnable Parameters
+ >>> m = SynchronizedBatchNorm3d(100, affine=False)
+ >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+ >>> output = m(input)
+ """
+
+ def _check_input_dim(self, input):
+ if input.dim() != 5:
+ raise ValueError('expected 5D input (got {}D input)'
+ .format(input.dim()))
+ super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/code_animation2video/sync_batchnorm/comm.py b/code_animation2video/sync_batchnorm/comm.py
new file mode 100644
index 0000000..922f8c4
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+ """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+ def __init__(self):
+ self._result = None
+ self._lock = threading.Lock()
+ self._cond = threading.Condition(self._lock)
+
+ def put(self, result):
+ with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+ self._result = result
+ self._cond.notify()
+
+ def get(self):
+ with self._lock:
+ if self._result is None:
+ self._cond.wait()
+
+ res = self._result
+ self._result = None
+ return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+ """Pipe for master-slave communication."""
+
+ def run_slave(self, msg):
+ self.queue.put((self.identifier, msg))
+ ret = self.result.get()
+ self.queue.put(True)
+ return ret
+
+
+class SyncMaster(object):
+ """An abstract `SyncMaster` object.
+
+    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
+ call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+ - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+ and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message to be passed
+      back to each slave device.
+ """
+
+ def __init__(self, master_callback):
+ """
+
+ Args:
+ master_callback: a callback to be invoked after having collected messages from slave devices.
+ """
+ self._master_callback = master_callback
+ self._queue = queue.Queue()
+ self._registry = collections.OrderedDict()
+ self._activated = False
+
+ def __getstate__(self):
+ return {'master_callback': self._master_callback}
+
+ def __setstate__(self, state):
+ self.__init__(state['master_callback'])
+
+ def register_slave(self, identifier):
+ """
+        Register a slave device.
+
+ Args:
+ identifier: an identifier, usually is the device id.
+
+ Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+ """
+ if self._activated:
+ assert self._queue.empty(), 'Queue is not clean before next initialization.'
+ self._activated = False
+ self._registry.clear()
+ future = FutureResult()
+ self._registry[identifier] = _MasterRegistry(future)
+ return SlavePipe(identifier, self._queue, future)
+
+ def run_master(self, master_msg):
+ """
+ Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+ Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+ message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+ Returns: the message to be sent back to the master device.
+
+ """
+ self._activated = True
+
+ intermediates = [(0, master_msg)]
+ for i in range(self.nr_slaves):
+ intermediates.append(self._queue.get())
+
+ results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+ for i, res in results:
+ if i == 0:
+ continue
+ self._registry[i].result.put(res)
+
+ for i in range(self.nr_slaves):
+ assert self._queue.get() is True
+
+ return results[0][1]
+
+ @property
+ def nr_slaves(self):
+ return len(self._registry)
diff --git a/code_animation2video/sync_batchnorm/replicate.py b/code_animation2video/sync_batchnorm/replicate.py
new file mode 100644
index 0000000..b71c7b8
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+ 'CallbackContext',
+ 'execute_replication_callbacks',
+ 'DataParallelWithCallback',
+ 'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+ pass
+
+
+def execute_replication_callbacks(modules):
+ """
+    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all modules are isomorphic, we assign each sub-module with a context
+ (shared among multiple copies of this module on different devices).
+ Through this context, different copies can share some information.
+
+ We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+ of any slave copies.
+ """
+ master_copy = modules[0]
+ nr_modules = len(list(master_copy.modules()))
+ ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+ for i, module in enumerate(modules):
+ for j, m in enumerate(module.modules()):
+ if hasattr(m, '__data_parallel_replicate__'):
+ m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+ """
+ Data Parallel with a replication callback.
+
+    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+ original `replicate` function.
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ # sync_bn.__data_parallel_replicate__ will be invoked.
+ """
+
+ def replicate(self, module, device_ids):
+ modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+
+def patch_replication_callback(data_parallel):
+ """
+ Monkey-patch an existing `DataParallel` object. Add the replication callback.
+ Useful when you have customized `DataParallel` implementation.
+
+ Examples:
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+ > patch_replication_callback(sync_bn)
+ # this is equivalent to
+ > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+ > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+ """
+
+ assert isinstance(data_parallel, DataParallel)
+
+ old_replicate = data_parallel.replicate
+
+ @functools.wraps(old_replicate)
+ def new_replicate(module, device_ids):
+ modules = old_replicate(module, device_ids)
+ execute_replication_callbacks(modules)
+ return modules
+
+ data_parallel.replicate = new_replicate
diff --git a/code_animation2video/sync_batchnorm/unittest.py b/code_animation2video/sync_batchnorm/unittest.py
new file mode 100644
index 0000000..0675c02
--- /dev/null
+++ b/code_animation2video/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : maojiayuan@gmail.com
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+ if isinstance(v, Variable):
+ v = v.data
+ return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+ def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+ npa, npb = as_numpy(a), as_numpy(b)
+ self.assertTrue(
+ np.allclose(npa, npb, atol=atol),
+ 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+ )
diff --git a/code_animation2video/test_data/chuanpu.wav b/code_animation2video/test_data/chuanpu.wav
new file mode 100644
index 0000000..60fab9c
Binary files /dev/null and b/code_animation2video/test_data/chuanpu.wav differ
diff --git a/code_animation2video/test_data/mengnalisa.jpg b/code_animation2video/test_data/mengnalisa.jpg
new file mode 100644
index 0000000..622e18e
Binary files /dev/null and b/code_animation2video/test_data/mengnalisa.jpg differ
diff --git a/code_animation2video/test_data/taile.jpg b/code_animation2video/test_data/taile.jpg
new file mode 100644
index 0000000..23f77ec
Binary files /dev/null and b/code_animation2video/test_data/taile.jpg differ
diff --git a/code_animation2video/utils.py b/code_animation2video/utils.py
new file mode 100644
index 0000000..5caf246
--- /dev/null
+++ b/code_animation2video/utils.py
@@ -0,0 +1,24 @@
+import numpy as np
+import cv2
+
+def make_coordinate_grid(image_size):
+ h, w = image_size
+ x = np.arange(w)
+ y = np.arange(h)
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ xx = x.reshape(1, -1).repeat(h, axis=0)
+ yy = y.reshape(-1, 1).repeat(w, axis=1)
+ meshed = np.stack([xx, yy], 2)
+ return meshed
+
+def visualize_dense_flow(dense_flow):
+    # encode flow direction as hue and flow magnitude as value in HSV space
+    hsv = np.zeros([dense_flow.shape[0], dense_flow.shape[1]])
+    hsv = np.stack([hsv, hsv, hsv], 2).astype(np.uint8)
+    hsv[..., 1] = 255
+    mag, ang = cv2.cartToPolar(dense_flow[:, :, 0], dense_flow[:, :, 1])
+    hsv[..., 0] = ang * 180 / np.pi / 2
+    hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)  # BGR order, ready for cv2.VideoWriter
+    return bgr.astype(np.uint8)
\ No newline at end of file
diff --git a/code_constructing_Fapp/inference.py b/code_constructing_Fapp/inference.py
new file mode 100644
index 0000000..550a3e4
--- /dev/null
+++ b/code_constructing_Fapp/inference.py
@@ -0,0 +1,132 @@
+import numpy as np
+from scipy.interpolate import griddata
+import argparse
+import os
+
+def to_image(projected_points, h, w):
+    '''
+    shift projected points from a center-origin coordinate system to image
+    coordinates with the origin at the top-left corner
+    '''
+    image_vertices = projected_points.copy()
+    image_vertices[:,0] = image_vertices[:,0] + w/2
+    image_vertices[:,1] = image_vertices[:,1] + h/2
+    image_vertices[:,1] = h - image_vertices[:,1] - 1
+    return image_vertices
+
+def orthogonal_transform(points3D, scale, R, t):
+ '''
+ orthogonal transform
+ '''
+ t3d = np.squeeze(np.array(t, dtype = np.float32))
+ transformed_vertices = scale * points3D.dot(R.T) + t3d[np.newaxis, :]
+
+ return transformed_vertices
+
+def project_to_image(points3D, scale, R, t, h, w):
+    '''
+    project 3D points onto the 2D image plane with an orthogonal projection
+    '''
+    projected_points = orthogonal_transform(points3D, scale, R, t)
+    projected_points = to_image(projected_points, h, w)
+    return projected_points
+
+def compute_projected_mesh_points(shape_para,exp_para,R,T,scale,image_size,model_3dmm):
+ '''
+ compute the projected mesh points from facial animation parameters
+    :param shape_para: the shape parameter of the reference face
+    :param exp_para: the expression parameter of the driving face
+    :param R: the head rotation in the orthogonal projection
+    :param T: the head translation in the orthogonal projection
+    :param scale: the scale parameter in the orthogonal projection
+    :param image_size: the size of the image
+    :param model_3dmm: the 3DMM model
+ :return: projected_mesh_points
+ '''
+ shape_para = np.expand_dims(shape_para, 1)
+ exp_para = np.expand_dims(exp_para, 1)
+ R_matrix = R.reshape((3, 3))
+ ## compute 3d mesh points in 3DMM
+ mesh_points_3D = model_3dmm.generate_vertices(shape_para, exp_para)
+ ## project 3D points to 2D plane
+ projected_2Dpoints = project_to_image(mesh_points_3D, scale, R_matrix, T, image_size[0], image_size[1])
+ return projected_2Dpoints
+
+
+def make_coordinate_grid(image_size):
+ h, w = image_size
+ x = np.arange(w)
+ y = np.arange(h)
+ x = (2 * (x / (w - 1)) - 1)
+ y = (2 * (y / (h - 1)) - 1)
+ xx = x.reshape(1, -1).repeat(h, axis=0)
+ yy = y.reshape(-1, 1).repeat(w, axis=1)
+ meshed = np.stack([xx, yy], 2)
+ return meshed
+
+def construct_Fapp(reference_projected_mesh_points,
+ drive_projected_mesh_points,image_size):
+ '''
+ compute Fapp from projected mesh points
+ reference_projected_mesh_points: the projected mesh points of reference image
+ drive_projected_mesh_points: the driving projected mesh points
+ '''
+ ## resize to -1 ~ 1
+ reference_projected_mesh_points = (reference_projected_mesh_points / image_size * 2) - 1
+    ### compute the lowest row (max height) of the face in pixel coordinates
+    face_max_h = int(np.max(drive_projected_mesh_points[:, 1]))
+ ## resize to -1 ~ 1
+ drive_projected_mesh_points = (drive_projected_mesh_points / image_size * 2) - 1
+ drive_projected_mesh_points_yx = drive_projected_mesh_points[:, [1, 0]]
+ ## compute sparse dense flow
+ sparse_dense_flow = reference_projected_mesh_points - drive_projected_mesh_points
+ ## compute average head motion
+ mean_dense_flow = np.mean(sparse_dense_flow, axis=0)
+ ## compute the dense flow in head-related region
+ grid_nums = complex(str(image_size) + "j")
+ grid_y, grid_x = np.mgrid[-1:1:grid_nums, -1:1:grid_nums]
+ dense_foreground_flow_x = griddata(drive_projected_mesh_points_yx, sparse_dense_flow[:, 0], (grid_y, grid_x), method='nearest')
+ ## compute dense flow in torso related region
+ dense_foreground_flow_x[face_max_h:, :] = mean_dense_flow[0]
+ dense_foreground_flow_y = griddata(drive_projected_mesh_points_yx, sparse_dense_flow[:, 1], (grid_y, grid_x), method='nearest')
+ dense_foreground_flow_y[face_max_h:, :] = mean_dense_flow[1]
+ Fapp = np.stack([dense_foreground_flow_x, dense_foreground_flow_y], 2)
+ ## transform into grid data
+ grid_mesh = make_coordinate_grid((image_size,image_size))
+ Fapp = grid_mesh + Fapp
+
+ return Fapp
+
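+# Usage sketch (illustration, not part of the original code): for reference
+# points `src` and one frame of driving points `drv`, both of shape
+# (num_points, 2) in pixel coordinates, construct_Fapp(src, drv, 512) returns
+# a (512, 512, 2) sampling grid in the [-1, 1] grid convention.
+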
+def parse_opts():
+ parser = argparse.ArgumentParser(description='construct Fapp')
+ parser.add_argument('--reference_projected_mesh_points_path', type=str,
+ default='./test_data/taile_source_points.npy',
+ help='the projected mesh points of reference image')
+ parser.add_argument('--drive_projected_mesh_points_path', type=str,
+ default='./test_data/taile_drive_points.npy',
+ help='the driving projected mesh points')
+ parser.add_argument('--image_size', type=int, default=512, help='the size of image')
+ parser.add_argument('--res_dir', type=str,default='./result',help='the dir of results')
+ args = parser.parse_args()
+ return args
+
+if __name__ == "__main__":
+    '''
+    We are not allowed to share the 3DMM model, so we only release the inference
+    code that constructs Fapp from projected mesh points. The function
+    "compute_projected_mesh_points" shows how to compute the projected mesh
+    points from facial animation parameters.
+    '''
+ opt = parse_opts()
+ reference_projected_mesh_points = np.load(opt.reference_projected_mesh_points_path)
+ drive_projected_mesh_points = np.load(opt.drive_projected_mesh_points_path)
+ frame_num = drive_projected_mesh_points.shape[0]
+ res_Fapp = []
+ for i in range(frame_num):
+ print('construct {}/{} Fapp'.format(i,frame_num))
+ Fapp_i = construct_Fapp(reference_projected_mesh_points,
+ drive_projected_mesh_points[i,:,:],opt.image_size)
+ res_Fapp.append(Fapp_i)
+ res_Fapp = np.stack(res_Fapp,0)
+ res_Fapp_path = os.path.join(opt.res_dir,os.path.basename(opt.reference_projected_mesh_points_path).replace('_source_points','_Fapp'))
+ np.save(res_Fapp_path,res_Fapp)
+
\ No newline at end of file
diff --git a/code_constructing_Fapp/readme.md b/code_constructing_Fapp/readme.md
new file mode 100644
index 0000000..33c79d4
--- /dev/null
+++ b/code_constructing_Fapp/readme.md
@@ -0,0 +1,12 @@
+# Code of constructing Fapp
+### inference
+
+ 1. Download the projected facial points (`mengnalisa_source_points.npy`, `mengnalisa_drive_points.npy`, `taile_source_points.npy`, `taile_drive_points.npy`) from [google drive](https://drive.google.com/drive/folders/1OM3AE6rjZKY1v6PVDnv-YwlmkBZOhw1L?usp=sharing).
+ 2. Put all files into **./test_data**.
+ 3. Run
+> python inference.py --reference_projected_mesh_points_path=./test_data/taile_source_points.npy --drive_projected_mesh_points_path=./test_data/taile_drive_points.npy
+
+or
+> python inference.py --reference_projected_mesh_points_path=./test_data/mengnalisa_source_points.npy --drive_projected_mesh_points_path=./test_data/mengnalisa_drive_points.npy
+
+to compute Fapp.
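+
+The resulting Fapp array is saved as `<reference name>_Fapp.npy` in the `--res_dir` directory (**./result** by default), and it can be fed to the animation-to-video module through its `--dense_flow_path` option.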
\ No newline at end of file
diff --git a/code_constructing_Fapp/requirements.txt b/code_constructing_Fapp/requirements.txt
new file mode 100644
index 0000000..7293049
--- /dev/null
+++ b/code_constructing_Fapp/requirements.txt
@@ -0,0 +1,2 @@
+scipy==1.3.1
+numpy==1.21.2
diff --git a/download.py b/download.py
index 3e21b61..8d80d32 100644
--- a/download.py
+++ b/download.py
@@ -10,25 +10,62 @@
$ python download.py --output_dir /tmp/data/hdtf --num_workers 8
```
-You need tqdm and youtube-dl libraries to be installed for this script to work.
+You need the tqdm, yt_dlp, and colorama libraries installed for this script to work.
"""
-
import os
import argparse
+import subprocess
+import pprint
from typing import List, Dict
from multiprocessing import Pool
-import subprocess
-from subprocess import Popen, PIPE
from urllib import parse
+import yt_dlp
from tqdm import tqdm
-
+from colorama import init as cinit
+from colorama import Fore
subsets = ["RD", "WDA", "WRA"]
+cinit(autoreset=True)
def download_hdtf(source_dir: os.PathLike, output_dir: os.PathLike, num_workers: int, **process_video_kwargs):
+ """
+ Downloads and processes videos from the HDTF dataset in parallel using multiprocessing.
+
+ The function manages the download process by:
+ - Creating the necessary output directories.
+ - Constructing a download queue from files in the specified source directory.
+ - Using a multiprocessing pool to handle downloads and subsequent processing.
+ - Providing progress tracking with tqdm.
+
+ After completing the download, a message is displayed with optional cleanup instructions to delete
+ temporary raw video files to save space.
+
+ Args:
+ source_dir (os.PathLike): The directory containing HDTF metadata files, including video URLs,
+ crop data, time intervals, and resolution information for each video subset.
+ output_dir (os.PathLike): The directory where downloaded videos and processed files will be saved.
+ num_workers (int): The number of parallel worker processes to use for downloading.
+ **process_video_kwargs: Additional keyword arguments passed to `download_and_process_video`,
+ allowing custom settings for processing each video.
+
+ Workflow:
+ 1. Creates the primary output directory and a subdirectory `_videos_raw` for raw downloads.
+ 2. Calls `construct_download_queue` to prepare a list of video download tasks based on the metadata
+ available in `source_dir`. Each entry in the queue includes details needed for downloading and processing.
+ 3. Uses a multiprocessing `Pool` to execute `download_and_process_video` for each video in `download_queue`,
+ with progress displayed via tqdm.
+ 4. After completing downloads, provides a message about optional cleanup for temporary video files.
+
+ Returns:
+ None
+
+ Raises:
+ AssertionError: If certain data inconsistencies are detected during download queue construction, such as
+ missing or malformed intervals, crops, or resolution information.
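+
+    Example:
+        >>> download_hdtf('HDTF_dataset', '/tmp/data/hdtf', num_workers=8)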
+ """
os.makedirs(output_dir, exist_ok=True)
os.makedirs(os.path.join(output_dir, '_videos_raw'), exist_ok=True)
@@ -37,57 +74,120 @@ def download_hdtf(source_dir: os.PathLike, output_dir: os.PathLike, num_workers:
video_data=vd,
output_dir=output_dir,
**process_video_kwargs,
- ) for vd in download_queue]
+ ) for vd in download_queue]
pool = Pool(processes=num_workers)
- tqdm_kwargs = dict(total=len(task_kwargs), desc=f'Downloading videos into {output_dir} (note: without sound)')
+ tqdm_kwargs = dict(total=len(task_kwargs),
+ desc=f'Downloading videos into {output_dir}')
for _ in tqdm(pool.imap_unordered(task_proxy, task_kwargs), **tqdm_kwargs):
pass
+ pool.close()
+ pool.join()
- print('Download is finished, you can now (optionally) delete the following directories, since they are not needed anymore and occupy a lot of space:')
- print(' -', os.path.join(output_dir, '_videos_raw'))
+ print(Fore.GREEN+'Download is finished, you can now (optionally) delete the following directories, since they are not needed anymore and occupy a lot of space:')
+ print(Fore.GREEN+' - '+os.path.join(output_dir, '_videos_raw'))
def construct_download_queue(source_dir: os.PathLike, output_dir: os.PathLike) -> List[Dict]:
+ """
+ Constructs a queue of videos to be downloaded and processed based on metadata from the HDTF dataset.
+
+ This function reads metadata files for each subset in the HDTF dataset, which provide information on:
+ - Video URLs.
+ - Time intervals indicating segments to be extracted from each video.
+ - Crop coordinates defining the regions of interest.
+ - Resolution information for each video.
+
+    For each valid video, an entry is created in the download queue with the information
+    required for downloading, cropping, and segmenting; videos with missing or malformed
+    metadata are discarded, with a warning printed to the console.
+
+ Args:
+ source_dir (os.PathLike): Path to the directory containing metadata files (`*_video_url.txt`,
+ `*_crop_wh.txt`, `*_annotion_time.txt`, and `*_resolution.txt`) for each subset.
+ output_dir (os.PathLike): Path to the directory where the downloaded and processed videos will be stored.
+
+ Returns:
+ List[Dict]: A list of dictionaries, each representing a video to download and process. Each dictionary
+ contains the following keys:
+ - 'name': Combined subset and video name identifier.
+ - 'id': YouTube video ID extracted from the video URL.
+ - 'intervals': List of start and end times for each clip segment.
+ - 'crops': List of crop coordinates for each segment.
+ - 'output_dir': The output directory path for this video.
+ - 'resolution': Desired resolution for the video.
+
+ Raises:
+ AssertionError: If the video segment data is inconsistent, such as:
+ - Missing or malformed time intervals.
+ - Incomplete or non-square crop data.
+ These assertions ensure that only well-formed entries are added to the download queue.
+
+ Example:
+ >>> construct_download_queue("HDTF_dataset", "/tmp/data/hdtf")
+ [{'name': 'RD_sample_video', 'id': 'abc123', 'intervals': [[0, 10], [15, 25]],
+ 'crops': [[0, 128, 0, 128], [0, 128, 0, 128]], 'output_dir': '/tmp/data/hdtf', 'resolution': '720p'}]
+ """
download_queue = []
for subset in subsets:
- video_urls = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_video_url.txt'))
- crops = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_crop_wh.txt'))
- intervals = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_annotion_time.txt'))
- resolutions = read_file_as_space_separated_data(os.path.join(source_dir, f'{subset}_resolution.txt'))
+ video_urls = read_file_as_space_separated_data(
+ os.path.join(source_dir, f'{subset}_video_url.txt'))
+ crops = read_file_as_space_separated_data(
+ os.path.join(source_dir, f'{subset}_crop_wh.txt'))
+ intervals = read_file_as_space_separated_data(
+ os.path.join(source_dir, f'{subset}_annotion_time.txt'))
+ resolutions = read_file_as_space_separated_data(
+ os.path.join(source_dir, f'{subset}_resolution.txt'))
for video_name, (video_url,) in video_urls.items():
if not f'{video_name}.mp4' in intervals:
- print(f'Entire {subset}/{video_name} does not contain any clip intervals, hence is broken. Discarding it.')
+ print(
+ f'{Fore.RED}Clip {subset}/{video_name} does not contain any clip intervals. It will be discarded.')
continue
if not f'{video_name}.mp4' in resolutions or len(resolutions[f'{video_name}.mp4']) > 1:
- print(f'Entire {subset}/{video_name} does not contain the resolution (or it is in a bad format), hence is broken. Discarding it.')
+ print(f'{Fore.RED}Clip {subset}/{video_name} does not contain an appropriate resolution (or it is in a bad format). It will be discarded.')
continue
- all_clips_intervals = [x.split('-') for x in intervals[f'{video_name}.mp4']]
+ all_clips_intervals = [x.split('-')
+ for x in intervals[f'{video_name}.mp4']]
clips_crops = []
clips_intervals = []
for clip_idx, clip_interval in enumerate(all_clips_intervals):
clip_name = f'{video_name}_{clip_idx}.mp4'
if not clip_name in crops:
- print(f'Clip {subset}/{clip_name} is not present in crops, hence is broken. Discarding it.')
+ print(
+ f'{Fore.RED}Discarding Clip: {subset}/{clip_name}. Clip is not present in crops.')
continue
+ else:
+ print(f'{Fore.GREEN}Appending Clip: {subset}/{clip_name}')
clips_crops.append(crops[clip_name])
clips_intervals.append(clip_interval)
clips_crops = [list(map(int, cs)) for cs in clips_crops]
if len(clips_crops) == 0:
- print(f'Entire {subset}/{video_name} does not contain any crops, hence is broken. Discarding it.')
+ print(
+ f'{Fore.RED}Discarding {subset}/{video_name}. No cropped versions found.')
continue
assert len(clips_intervals) == len(clips_crops)
- assert set([len(vi) for vi in clips_intervals]) == {2}, f"Broken time interval, {clips_intervals}"
- assert set([len(vc) for vc in clips_crops]) == {4}, f"Broken crops, {clips_crops}"
- assert all([vc[1] == vc[3] for vc in clips_crops]), f'Some crops are not square, {clips_crops}'
+        assert set(len(vi) for vi in clips_intervals) == {2}, f"Broken time interval, {clips_intervals}"
+        assert set(len(vc) for vc in clips_crops) == {4}, f"Broken crops, {clips_crops}"
+        assert all(vc[1] == vc[3] for vc in clips_crops), f'Some crops are not square, {clips_crops}'
download_queue.append({
'name': f'{subset}_{video_name}',
@@ -100,45 +200,147 @@ def construct_download_queue(source_dir: os.PathLike, output_dir: os.PathLike) -
return download_queue
-
def task_proxy(kwargs):
+ """
+ A proxy function to execute `download_and_process_video` with unpacked keyword arguments.
+
+ This function serves as a wrapper that allows passing a dictionary of arguments (`kwargs`)
+ to the `download_and_process_video` function. It is primarily used in conjunction with
+ multiprocessing, where it enables the `Pool.imap_unordered` method to handle the video
+ processing tasks in parallel.
+
+ Args:
+ kwargs (dict): A dictionary of arguments required by `download_and_process_video`.
+ This typically includes:
+ - 'video_data': A dictionary containing video details (ID, name, intervals, crops, etc.).
+ - 'output_dir': The directory path where processed clips will be saved.
+
+ Returns:
+ None
+
+ Usage:
+ The `task_proxy` function is designed for use with parallel processing. By passing a dictionary
+ of arguments instead of positional arguments, it enables compatibility with the multiprocessing
+ pool's mapping methods.
+
+ Example:
+ >>> task_kwargs = {'video_data': {...}, 'output_dir': '/path/to/output'}
+ >>> task_proxy(task_kwargs)
+
+ Notes:
+ This function simplifies the interface for multiprocessing tasks, allowing
+ `download_and_process_video` to be used directly within the parallel processing workflow
+ without modifying its original function signature.
+ """
return download_and_process_video(**kwargs)
+
def download_and_process_video(video_data: Dict, output_dir: str):
"""
- Downloads the video and cuts/crops it into several ones according to the provided time intervals
+    Downloads a YouTube video and cuts/crops it into clips according to the
+    provided time intervals and crop coordinates.
+
+    The raw video is saved as `{video_data['name']}.mp4` in the `_videos_raw`
+    subdirectory of `output_dir`. For each interval in `video_data['intervals']`,
+    the corresponding segment is extracted, cropped with the matching entry of
+    `video_data['crops']`, and saved in `output_dir` as
+    `{video_data['name']}_{clip_idx:03d}.mp4`.
+
+    Args:
+        video_data (dict): Metadata for the video to be downloaded and processed.
+            Expected keys:
+            - 'id': The YouTube ID of the video.
+            - 'name': A unique name identifier for the video.
+            - 'intervals': A list of (start, end) times for each clip segment.
+            - 'crops': A list of crop coordinates (x, width, y, height) per clip.
+            - 'resolution': The desired resolution of the video.
+        output_dir (str): Path to the directory where processed clips are saved.
+
+    Returns:
+        None. If the download fails, an error message pointing to the download
+        log is printed and the video is skipped. A failed cut-and-crop is
+        reported per clip and does not abort the remaining clips.
+
+    Example:
+        >>> video_data = {
+        ...     'id': 'abc123',
+        ...     'name': 'sample_video',
+        ...     'intervals': [[0, 10], [15, 25]],
+        ...     'crops': [[0, 128, 0, 128], [10, 118, 10, 118]],
+        ...     'output_dir': '/tmp/data/hdtf',
+        ...     'resolution': '720p',
+        ... }
+        >>> download_and_process_video(video_data, '/tmp/data/hdtf')
+
+    Notes:
+        - Requires `ffmpeg` to be installed for segmenting and cropping the clips.
"""
- raw_download_path = os.path.join(output_dir, '_videos_raw', f"{video_data['name']}.mp4")
- raw_download_log_file = os.path.join(output_dir, '_videos_raw', f"{video_data['name']}_download_log.txt")
- download_result = download_video(video_data['id'], raw_download_path, resolution=video_data['resolution'], log_file=raw_download_log_file)
+ raw_download_path = os.path.join(
+ output_dir, '_videos_raw', f"{video_data['name']}.mp4")
+ raw_download_log_file = os.path.join(
+ output_dir, '_videos_raw', f"{video_data['name']}_download_log.txt")
+ print(f"{Fore.LIGHTBLUE_EX} raw_download_path: {raw_download_path}")
- if not download_result:
- print('Failed to download', video_data)
- print(f'See {raw_download_log_file} for details')
- return
+    # download_video returns (path, success); only the success flag is needed here.
+    _, download_result = download_video(
+        video_data['id'], raw_download_path, log_file=raw_download_log_file)
- # We do not know beforehand, what will be the resolution of the downloaded video
- # Youtube-dl selects a (presumably) highest one
- video_resolution = get_video_resolution(raw_download_path)
- if not video_resolution != video_data['resolution']:
- print(f"Downloaded resolution is not correct for {video_data['name']}: {video_resolution} vs {video_data['name']}. Discarding this video.")
+ if not download_result:
+ print(f'{Fore.RED} Failed to download {video_data["name"]}')
+ print(f'{Fore.RED} See {raw_download_log_file} for details')
return
for clip_idx in range(len(video_data['intervals'])):
start, end = video_data['intervals'][clip_idx]
clip_name = f'{video_data["name"]}_{clip_idx:03d}'
clip_path = os.path.join(output_dir, clip_name + '.mp4')
- crop_success = cut_and_crop_video(raw_download_path, clip_path, start, end, video_data['crops'][clip_idx])
+ crop_success = cut_and_crop_video(
+ raw_download_path, clip_path, start, end, video_data['crops'][clip_idx])
if not crop_success:
- print(f'Failed to cut-and-crop clip #{clip_idx}', video_data)
+ print(f'{Fore.RED} Failed to cut-and-crop clip #{clip_idx}')
+ pprint.pprint(video_data, indent=4, sort_dicts=False)
continue
-
def read_file_as_space_separated_data(filepath: os.PathLike) -> Dict:
"""
- Reads a file as a space-separated dataframe, where the first column is the index
+ Reads a space-separated file and returns its contents as a dictionary.
+
+ This function reads a text file where each line contains space-separated values.
+ The first value in each line is treated as the key, and the remaining values are
+ stored as a list associated with that key. This is useful for parsing metadata
+ files with a consistent space-separated format.
+
+ Args:
+ filepath (os.PathLike): The path to the file to be read.
+
+ Returns:
+ Dict: A dictionary where each key corresponds to the first item in a line,
+ and each value is a list of the remaining items in that line.
+
+ Example:
+ Suppose `example.txt` contains:
+ video1 1280 720
+ video2 640 480
+ >>> read_file_as_space_separated_data("example.txt")
+ {'video1': ['1280', '720'], 'video2': ['640', '480']}
+
+ Notes:
+ - Blank lines are not supported and may cause errors.
+ - Each line must contain at least one space-separated value to be valid.
+
+ Raises:
+ IOError: If the file cannot be opened or read.
"""
with open(filepath, 'r') as f:
lines = f.read().splitlines()
@@ -148,91 +350,196 @@ def read_file_as_space_separated_data(filepath: os.PathLike) -> Dict:
return data
-def download_video(video_id, download_path, resolution: int=None, video_format="mp4", log_file=None):
+def download_video(video_id, download_path, video_format="bestvideo+bestaudio", log_file=None):
"""
- Download video from YouTube.
- :param video_id: YouTube ID of the video.
- :param download_path: Where to save the video.
- :param video_format: Format to download.
- :param log_file: Path to a log file for youtube-dl.
- :return: Tuple: path to the downloaded video and a bool indicating success.
-
- Copy-pasted from https://github.com/ytdl-org/youtube-dl
+ Downloads a YouTube video in the specified format and saves it to a given path.
+
+    This function uses `yt-dlp` to download a video by its YouTube ID, selecting the
+    highest available quality by default. A custom yt-dlp format string can be passed
+    via `video_format`, and download progress and errors can be logged to a file.
+
+ Args:
+ video_id (str): The YouTube ID of the video to download.
+ download_path (str): The full path (including file name) where the downloaded video will be saved.
+ video_format (str, optional): The video and audio format selection for yt-dlp. Defaults to
+ "bestvideo+bestaudio" for highest available quality.
+ log_file (str, optional): Path to a file where log messages (debug, warnings, and errors)
+ will be recorded. If None, logging to a file is disabled.
+
+ Returns:
+ Tuple[str, bool]: A tuple where:
+ - The first element is the path to the downloaded video file.
+ - The second element is a boolean indicating success (True if the file
+ was downloaded successfully, False otherwise).
+
+ Workflow:
+ 1. Constructs `yt-dlp` options based on the provided arguments, including `format`, `outtmpl`,
+ and `logger` if a log file is specified.
+ 2. Attempts to download the video. If successful, verifies the file exists at `download_path`.
+ 3. Logs errors if the download fails and saves them to `log_file` if specified.
+
+    Error handling:
+        Exceptions raised during the download are caught and, if `log_file` is
+        provided, appended to the log; the success flag in the returned tuple is
+        then False.
+
+ Example:
+ >>> download_video("abc123", "/path/to/video.mp4", log_file="/path/to/log.txt")
+ ("/path/to/video.mp4", True)
+
+ Notes:
+ - Requires `yt-dlp` to be installed.
+ - Requires `ffmpeg` if merging video and audio streams is necessary.
+ - Custom logging is provided through a nested `Logger` class if `log_file` is specified.
"""
- # if os.path.isfile(download_path): return True # File already exists
-
- if log_file is None:
- stderr = subprocess.DEVNULL
- else:
- stderr = open(log_file, "a")
- video_selection = f"bestvideo[ext={video_format}]"
- video_selection = video_selection if resolution is None else f"{video_selection}[height={resolution}]"
- command = [
- "youtube-dl",
- "https://youtube.com/watch?v={}".format(video_id), "--quiet", "-f",
- video_selection,
- "--output", download_path,
- "--no-continue"
- ]
- return_code = subprocess.call(command, stderr=stderr)
- success = return_code == 0
-
- if log_file is not None:
- stderr.close()
-
- return success and os.path.isfile(download_path)
-
-
-def get_video_resolution(video_path: os.PathLike) -> int:
- command = ' '.join([
- "ffprobe",
- "-v", "error",
- "-select_streams", "v:0", "-show_entries", "stream=height", "-of", "csv=p=0",
- video_path
- ])
-
- process = Popen(command, stdout=PIPE, shell=True)
- (output, err) = process.communicate()
- return_code = process.wait()
- success = return_code == 0
-
- if not success:
- print('Command failed:', command)
- return -1
-
- return int(output)
-
+    class Logger:
+        """Minimal yt-dlp logger that appends debug, warning, and error
+        messages to the file at `log_path`."""
+
+        def __init__(self, log_path):
+            self.log_path = log_path
+
+        def debug(self, msg):
+            with open(self.log_path, "a") as f:
+                f.write(f"DEBUG: {msg}\n")
+
+        def warning(self, msg):
+            with open(self.log_path, "a") as f:
+                f.write(f"WARNING: {msg}\n")
+
+        def error(self, msg):
+            with open(self.log_path, "a") as f:
+                f.write(f"ERROR: {msg}\n")
+
+ # Define yt-dlp options
+ ydl_opts = {
+ # Set video format to best video and audio by default
+ 'format': video_format,
+ 'outtmpl': download_path, # Output path template
+ 'quiet': True, # Suppress verbose output
+ 'merge_output_format': 'mp4', # Ensure output format is MP4
+ }
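+    # Note: merging 'bestvideo+bestaudio' into a single mp4 relies on ffmpeg
+    # being installed, since yt-dlp uses it as the merge backend.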
+
+ # If a log file is specified, configure the logger
+ if log_file:
+ ydl_opts['logger'] = Logger(log_file)
+
+ # Download the video using yt-dlp
+ try:
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+ ydl.download([f'https://www.youtube.com/watch?v={video_id}'])
+ success = True
+ except Exception as e:
+ success = False
+ if log_file:
+ with open(log_file, "a") as f:
+ f.write(
+ f"ERROR: Failed to download {video_id}. Exception: {str(e)}\n")
+
+ result = success and os.path.isfile(download_path)
+ return download_path, result
def cut_and_crop_video(raw_video_path, output_path, start, end, crop: List[int]):
- # if os.path.isfile(output_path): return True # File already exists
-
+ """
+ Cuts and crops a video segment from a specified start to end time and saves it to the output path.
+
+ This function uses `ffmpeg` to:
+ 1. Extract a segment of the video from `start` to `end` time.
+ 2. Apply a crop filter to the segment based on the provided crop coordinates.
+ 3. Save the processed clip to `output_path` with the original quality preserved.
+
+ Args:
+ raw_video_path (str): Path to the source video file to be processed.
+ output_path (str): Path where the processed video clip will be saved, including the file name.
+ start (float or int): Start time in seconds for the video segment to be cut.
+ end (float or int): End time in seconds for the video segment to be cut.
+ crop (List[int]): A list specifying crop parameters [x, width, y, height], where:
+ - x (int): The x-coordinate of the top-left corner of the crop area.
+ - width (int): The width of the crop area.
+ - y (int): The y-coordinate of the top-left corner of the crop area.
+ - height (int): The height of the crop area.
+
+ Returns:
+ bool: True if the cutting and cropping were successful, False otherwise.
+
+    Raises:
+        ValueError: If `crop` does not contain exactly four values.
+        FileNotFoundError: If `ffmpeg` is not installed or accessible from the system PATH.
+
+ Example:
+ >>> cut_and_crop_video(
+ raw_video_path="/path/to/source.mp4",
+ output_path="/path/to/clip.mp4",
+ start=10,
+ end=20,
+ crop=[50, 200, 30, 200]
+ )
+ True
+
+ Notes:
+ - Requires `ffmpeg` to be installed and accessible from the command line.
+ - If `output_path` already exists, it will be overwritten.
+ - `-qscale 0` is used to preserve the video quality.
+ - The crop filter uses the format `crop=width:height:x:y`, where `x` and `y` specify the top-left corner.
+ """
x, out_w, y, out_h = crop
- command = ' '.join([
+ command = [
"ffmpeg", "-i", raw_video_path,
- "-strict", "-2", # Some legacy arguments
- "-loglevel", "quiet", # Verbosity arguments
- "-qscale", "0", # Preserve the quality
- "-y", # Overwrite if the file exists
- "-ss", str(start), "-to", str(end), # Cut arguments
- "-filter:v", f'"crop={out_w}:{out_h}:{x}:{y}"', # Crop arguments
+ "-strict", "-2", # Some legacy arguments
+ "-loglevel", "quiet", # Verbosity arguments
+ "-qscale", "0", # Preserve the quality
+ "-y", # Overwrite if the file exists
+ "-ss", str(start),
+ "-to", str(end),
+ "-filter:v", f"crop={out_w}:{out_h}:{x}:{y}", # Crop arguments
output_path
- ])
-
- return_code = subprocess.call(command, shell=True)
+ ]
+ return_code = subprocess.call(command)
success = return_code == 0
if not success:
- print('Command failed:', command)
+ print(f'{Fore.RED} Command failed: {" ".join(command)}')
return success
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Download HDTF dataset")
- parser.add_argument('-s', '--source_dir', type=str, default='HDTF_dataset', help='Path to the directory with the dataset')
- parser.add_argument('-o', '--output_dir', type=str, help='Where to save the videos?')
- parser.add_argument('-w', '--num_workers', type=int, default=8, help='Number of workers for downloading')
+ parser.add_argument('-s', '--source_dir', type=str, default='HDTF_dataset',
+ help='Path to the directory with the dataset description')
+ parser.add_argument('-o', '--output_dir', type=str, default='download',
+ help='Where to save the videos?')
+ parser.add_argument('-w', '--num_workers', type=int, default=1,
+ help='Number of workers for downloading.')
args = parser.parse_args()
download_hdtf(