From e54857dddbfba52a64b5264bd17e5ba3fe933edd Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 09:45:54 -0600 Subject: [PATCH 1/8] Ignore working files for a proper IDE... --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index cdb0f4cc..71ff561b 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ pretrained_models ./*.png ./*.mp4 demo/tmp -demo/outputs \ No newline at end of file +demo/outputs +.idea/* \ No newline at end of file From 463c34580389e6af593212faf47b609398318163 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 09:46:17 -0600 Subject: [PATCH 2/8] Cross-platform paths! --- demo/models.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 demo/models.py diff --git a/demo/models.py b/demo/models.py new file mode 100644 index 00000000..75e2b5e4 --- /dev/null +++ b/demo/models.py @@ -0,0 +1,19 @@ +import os + +script_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")) +models_path = os.path.join(script_path, "pretrained_models") +magic_models_path = os.path.join(models_path, "MagicAnimate") + +pretrained_model_path = os.path.join(models_path, "stable-diffusion-v1-5") +pretrained_vae_path = os.path.join(models_path, "pretrained_vae") + +pretrained_controlnet_path = os.path.join(magic_models_path, "densepose_controlnet") +pretrained_encoder_path = os.path.join(magic_models_path, "appearance_encoder") +pretrained_motion_module_path = os.path.join(magic_models_path, "temporal_attention") +motion_module = os.path.join(pretrained_motion_module_path, "temporal_attention.ckpt") + +pretrained_unet_path = "" + +config_path = os.path.join(script_path, "configs", "prompts", "animation.yaml") +inference_config_path = os.path.join(script_path, "configs", "inference", "inference.yaml") + From 22d0ca8c2567fb643d8d30b51eb001646ae6c60b Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 09:46:39 -0600 Subject: [PATCH 3/8] Cleanup requirements --- environment.yaml | 1 + requirements.txt | 20 +++++--------------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/environment.yaml b/environment.yaml index 00bcc3f4..c8dec364 100644 --- a/environment.yaml +++ b/environment.yaml @@ -2,6 +2,7 @@ name: manimate channels: - conda-forge - defaults + - nvidia dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=5.1=1_gnu diff --git a/requirements.txt b/requirements.txt index 82f59c87..83a35267 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,8 @@ click==8.1.7 cmake==3.27.2 contourpy==1.1.0 cycler==0.11.0 -datasets==2.14.4 +datasets +diffusers==0.24.0 dill==0.3.7 einops==0.6.1 exceptiongroup==1.1.3 @@ -34,8 +35,9 @@ grpcio==1.57.0 h11==0.14.0 httpcore==0.17.3 httpx==0.24.1 -huggingface-hub==0.16.4 +huggingface-hub idna==3.4 +imageio importlib-metadata==6.8.0 importlib-resources==6.0.1 jinja2==3.1.2 @@ -53,17 +55,6 @@ multidict==6.0.4 multiprocess==0.70.15 networkx==3.1 numpy==1.24.4 -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 oauthlib==3.2.2 omegaconf==2.3.0 opencv-python==4.8.0.76 @@ -102,7 +93,6 @@ toolz==0.12.0 torchmetrics==1.1.0 tqdm==4.66.1 transformers==4.32.0 -triton==2.0.0 tzdata==2023.3 urllib3==1.26.16 
uvicorn==0.23.2 @@ -124,4 +114,4 @@ ffmpeg-python torch==2.0.1 torchvision==0.15.2 xformers==0.0.22 -diffusers==0.21.4 +--extra-index-url https://download.pytorch.org/whl/cu118 From 1d518b6a4c29f25a6078ad4c3174ade5d446256a Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 09:47:20 -0600 Subject: [PATCH 4/8] Fix diffusers import --- magicanimate/pipelines/pipeline_animation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/magicanimate/pipelines/pipeline_animation.py b/magicanimate/pipelines/pipeline_animation.py index 08a77bac..972f70fe 100644 --- a/magicanimate/pipelines/pipeline_animation.py +++ b/magicanimate/pipelines/pipeline_animation.py @@ -33,6 +33,7 @@ import numpy as np import torch import torch.distributed as dist +from diffusers import DiffusionPipeline from tqdm import tqdm from diffusers.utils import is_accelerate_available from packaging import version @@ -40,7 +41,6 @@ from diffusers.configuration_utils import FrozenDict from diffusers.models import AutoencoderKL -from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers import ( DDIMScheduler, DPMSolverMultistepScheduler, From 39ad631a74399129f89aa3152dbf8e0d60d495fe Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 09:52:49 -0600 Subject: [PATCH 5/8] Allow any checkpoint? --- demo/animate.py | 301 +++++++++++++++++++---------------- demo/gradio_animate.py | 91 ++++++++--- demo/{models.py => paths.py} | 3 + 3 files changed, 236 insertions(+), 159 deletions(-) rename demo/{models.py => paths.py} (82%) diff --git a/demo/animate.py b/demo/animate.py index b71f1940..db3123e0 100644 --- a/demo/animate.py +++ b/demo/animate.py @@ -8,91 +8,187 @@ # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. 
-import argparse -import argparse import datetime import inspect import os -import numpy as np -from PIL import Image -from omegaconf import OmegaConf from collections import OrderedDict +import numpy as np import torch - -from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler - -from tqdm import tqdm +from PIL import Image +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline +from einops import rearrange +from omegaconf import OmegaConf from transformers import CLIPTextModel, CLIPTokenizer -from magicanimate.models.unet_controlnet import UNet3DConditionModel -from magicanimate.models.controlnet import ControlNetModel +from demo.paths import config_path, inference_config_path, pretrained_encoder_path, pretrained_controlnet_path, \ + motion_module from magicanimate.models.appearance_encoder import AppearanceEncoderModel +from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.models.unet_controlnet import UNet3DConditionModel from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid -from accelerate.utils import set_seed - from magicanimate.utils.videoreader import VideoReader -from einops import rearrange, repeat -import csv, pdb, glob -from safetensors import safe_open -import math -from pathlib import Path +class MagicAnimate: + def __init__(self) -> None: + config = OmegaConf.load(config_path) -class MagicAnimate(): - def __init__(self, config="configs/prompts/animation.yaml") -> None: print("Initializing MagicAnimate Pipeline...") *_, func_args = inspect.getargvalues(inspect.currentframe()) - func_args = dict(func_args) - - config = OmegaConf.load(config) - - inference_config = OmegaConf.load(config.inference_config) - - motion_module = config.motion_module - - ### >>> create animation pipeline >>> ### - tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") - if config.pretrained_unet_path: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) - else: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) - self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda() - self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) - self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) - if config.pretrained_vae_path is not None: - vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) - else: - vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + self.func_args = dict(func_args) - ### Load controlnet - controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) + self.config = config - vae.to(torch.float16) - unet.to(torch.float16) - text_encoder.to(torch.float16) - controlnet.to(torch.float16) + inference_config = 
OmegaConf.load(inference_config_path) + self.inference_config = inference_config + + ### Load controlnet and appearance encoder + self.appearance_encoder = AppearanceEncoderModel.from_pretrained(pretrained_encoder_path, + subfolder="appearance_encoder").cuda() + + self.controlnet = ControlNetModel.from_pretrained(pretrained_controlnet_path) + self.controlnet.to(torch.float16) self.appearance_encoder.to(torch.float16) - - unet.enable_xformers_memory_efficient_attention() self.appearance_encoder.enable_xformers_memory_efficient_attention() - controlnet.enable_xformers_memory_efficient_attention() + self.controlnet.enable_xformers_memory_efficient_attention() - self.pipeline = AnimationPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, - scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), - # NOTE: UniPCMultistepScheduler - ).to("cuda") - # 1. unet ckpt - # 1.1 motion module + self.pipeline = None + self.reference_control_writer = None + self.reference_control_reader = None + self.L = config.L + + print("Initialization Done!") + + def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, checkpoint=None): + self.load_pipeline(checkpoint) + prompt = n_prompt = "" + random_seed = int(random_seed) + step = int(step) + guidance_scale = float(guidance_scale) + samples_per_video = [] + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + set_seed(random_seed) + else: + torch.seed() + + if motion_sequence.endswith('.mp4'): + control = VideoReader(motion_sequence).read() + if control[0].shape[0] != size: + control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] + control = np.array(control) + else: + control = np.load(motion_sequence) + if control.shape[1] != size: + control = np.array([np.array(Image.fromarray(c).resize((size, size))) for c in control]) + control = np.array(control) + + if source_image.shape[0] != size: + source_image = np.array(Image.fromarray(source_image).resize((size, size))) + H, W, C = source_image.shape + + init_latents = None + original_length = control.shape[0] + if control.shape[0] % self.L > 0: + control = np.pad(control, ((0, self.L - control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') + generator = torch.Generator(device=torch.device("cuda:0")) + generator.manual_seed(torch.initial_seed()) + sample = self.pipeline( + prompt, + negative_prompt=n_prompt, + num_inference_steps=step, + guidance_scale=guidance_scale, + width=W, + height=H, + video_length=len(control), + controlnet_condition=control, + init_latents=init_latents, + generator=generator, + appearance_encoder=self.appearance_encoder, + reference_control_writer=self.reference_control_writer, + reference_control_reader=self.reference_control_reader, + source_image=source_image, + ).videos + + source_images = np.array([source_image] * original_length) + source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + samples_per_video.append(source_images) + + control = control / 255.0 + control = rearrange(control, "t h w c -> 1 c t h w") + control = torch.from_numpy(control) + samples_per_video.append(control[:, :, :original_length]) + + samples_per_video.append(sample[:, :, :original_length]) + + samples_per_video = torch.cat(samples_per_video) + + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"demo/outputs" + animation_path = 
f"{savedir}/{time_str}.mp4" + + os.makedirs(savedir, exist_ok=True) + save_videos_grid(samples_per_video, animation_path) + + return animation_path + + def load_pipeline(self, model_path=None): + if self.pipeline is not None: + del self.pipeline + + config = self.config + inference_config = self.inference_config + vae = None + print(f"Loading pipeline from {model_path}") + if not model_path: + model_path = config.pretrained_model_path + unet_path = config.pretrained_unet_path if config.pretrained_unet_path else model_path + else: + unet_path = model_path + if "safetensors" in model_path or "ckpt" in model_path: + temp_pipeline = StableDiffusionPipeline.from_single_file(model_path) + tokenizer = temp_pipeline.tokenizer + text_encoder = temp_pipeline.text_encoder + unet = temp_pipeline.unet + try: + vae = temp_pipeline.vae + except: + print("No VAE found in ckpt, using default VAE") + else: + tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder") + if config.pretrained_unet_path: + unet = UNet3DConditionModel.from_pretrained_2d(unet_path, + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs)) + else: + unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs)) + + if vae is None: + if config.pretrained_vae_path is not None: + vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) + else: + vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + + self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, + do_classifier_free_guidance=True, mode='write', + fusion_blocks=config.fusion_blocks) + self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', + fusion_blocks=config.fusion_blocks) + motion_module_state_dict = torch.load(motion_module, map_location="cpu") - if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) - motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict + if "global_step" in motion_module_state_dict: self.func_args.update( + {"global_step": motion_module_state_dict["global_step"]}) + motion_module_state_dict = motion_module_state_dict[ + 'state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict try: # extra steps for self-trained models state_dict = OrderedDict() @@ -115,81 +211,16 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None: _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] - missing, unexpected = unet.load_state_dict(_tmp_, strict=False) - assert len(unexpected) == 0 del _tmp_ del motion_module_state_dict - self.pipeline.to("cuda") - self.L = config.L - - print("Initialization Done!") - - def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): - prompt = n_prompt = "" - random_seed = int(random_seed) - step = int(step) - guidance_scale = float(guidance_scale) - samples_per_video = [] - # manually set random seed for reproduction - if random_seed != -1: - torch.manual_seed(random_seed) - set_seed(random_seed) - else: - torch.seed() - - if motion_sequence.endswith('.mp4'): - control = VideoReader(motion_sequence).read() - 
if control[0].shape[0] != size: - control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] - control = np.array(control) - - if source_image.shape[0] != size: - source_image = np.array(Image.fromarray(source_image).resize((size, size))) - H, W, C = source_image.shape - - init_latents = None - original_length = control.shape[0] - if control.shape[0] % self.L > 0: - control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') - generator = torch.Generator(device=torch.device("cuda:0")) - generator.manual_seed(torch.initial_seed()) - sample = self.pipeline( - prompt, - negative_prompt = n_prompt, - num_inference_steps = step, - guidance_scale = guidance_scale, - width = W, - height = H, - video_length = len(control), - controlnet_condition = control, - init_latents = init_latents, - generator = generator, - appearance_encoder = self.appearance_encoder, - reference_control_writer = self.reference_control_writer, - reference_control_reader = self.reference_control_reader, - source_image = source_image, - ).videos - - source_images = np.array([source_image] * original_length) - source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 - samples_per_video.append(source_images) - - control = control / 255.0 - control = rearrange(control, "t h w c -> 1 c t h w") - control = torch.from_numpy(control) - samples_per_video.append(control[:, :, :original_length]) - - samples_per_video.append(sample[:, :, :original_length]) - - samples_per_video = torch.cat(samples_per_video) - - time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - savedir = f"demo/outputs" - animation_path = f"{savedir}/{time_str}.mp4" - - os.makedirs(savedir, exist_ok=True) - save_videos_grid(samples_per_video, animation_path) - - return animation_path - \ No newline at end of file + vae.to(torch.float16) + unet.to(torch.float16) + text_encoder.to(torch.float16) + unet.enable_xformers_memory_efficient_attention() + + self.pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=self.controlnet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + # NOTE: UniPCMultistepScheduler + ).to("cuda") diff --git a/demo/gradio_animate.py b/demo/gradio_animate.py index 9a932a28..c589c483 100644 --- a/demo/gradio_animate.py +++ b/demo/gradio_animate.py @@ -8,21 +8,39 @@ # disclosure or distribution of this material and related documentation # without an express license agreement from ByteDance or # its affiliates is strictly prohibited. 
-import argparse +import os + +import gradio as gr import imageio import numpy as np -import gradio as gr from PIL import Image +from huggingface_hub import hf_hub_download +from omegaconf import OmegaConf from demo.animate import MagicAnimate +from demo.paths import config_path, script_path, models_path, magic_models_path, source_images_path, \ + motion_sequences_path -animator = MagicAnimate() +animator = None -def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale): - return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale) -with gr.Blocks() as demo: +def list_checkpoints(): + checkpoint_dir = os.path.join(models_path, "checkpoints") + # Recursively find all .ckpt and .safetensors files + checkpoints = [""] + for root, dirs, files in os.walk(checkpoint_dir): + for file in files: + if file.endswith(".ckpt") or file.endswith(".safetensors"): + checkpoints.append(os.path.join(root, file)) + return checkpoints + +# source_image, motion_sequence, random_seed, step, guidance_scale, size=512, checkpoint=None +def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale, size, checkpoint=None): + return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale, size, checkpoint) + + +with gr.Blocks() as demo: gr.HTML( """
@@ -40,25 +58,30 @@ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale)
""") animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) - with gr.Row(): - reference_image = gr.Image(label="Reference Image") - motion_sequence = gr.Video(format="mp4", label="Motion Sequence") - + checkpoint = gr.Dropdown(label="Checkpoint", choices=list_checkpoints()) + with gr.Row(): + reference_image = gr.Image(label="Reference Image") + motion_sequence = gr.Video(format="mp4", label="Motion Sequence") + with gr.Column(): - random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") - sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") - guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") - submit = gr.Button("Animate") + size = gr.Slider(label="Size", value=512, min=256, max=1024, step=256, info="default: 512", visible=False) + random_seed = gr.Slider(label="Random seed", value=1, info="default: -1") + sampling_steps = gr.Slider(label="Sampling steps", value=25, info="default: 25") + guidance_scale = gr.Slider(label="Guidance scale", value=7.5, info="default: 7.5", step=0.1) + submit = gr.Button("Animate") + def read_video(video): reader = imageio.get_reader(video) fps = reader.get_meta_data()['fps'] return video - + + def read_image(image, size=512): return np.array(Image.fromarray(image).resize((size, size))) - + + # when user uploads a new video motion_sequence.upload( read_video, @@ -72,25 +95,45 @@ def read_image(image, size=512): reference_image ) # when the `submit` button is clicked + #source_image, motion_sequence, random_seed, step, guidance_scale, size = 512, checkpoint = None submit.click( animate, - [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], + [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, checkpoint], animation ) - # Examples + # source_images_path = os.path.join(script_path, "inputs", "applications", "source_image") + # motion_sequences_path = os.path.join(script_path, "inputs", "applications", "driving", "densepose") + #Examples gr.Markdown("## Examples") gr.Examples( examples=[ - ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"], - ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"], - ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"], - ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"], - ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"], + [f"{source_images_path}/monalisa.png", f"{motion_sequences_path}/running.mp4"], + [f"{source_images_path}/demo4.png", f"{motion_sequences_path}/demo4.mp4"], + [f"{source_images_path}/dalle2.jpeg", f"{motion_sequences_path}/running2.mp4"], + [f"{source_images_path}/dalle8.jpeg", f"{motion_sequences_path}/dancing2.mp4"], + [f"{source_images_path}/multi1_source.png", + f"{motion_sequences_path}/multi_dancing.mp4"], ], inputs=[reference_image, motion_sequence], outputs=animation, ) +if __name__ == '__main__': + if not os.path.exists(models_path): + os.mkdir(models_path) + + if not os.path.exists(os.path.join(models_path, "checkpoints")): + os.mkdir(os.path.join(models_path, "checkpoints")) + + if not os.path.exists(magic_models_path): + # git lfs clone https://huggingface.co/zcxu-eric/MagicAnimate, not hf_hub_download + git_lfs_path = os.path.join(models_path, "MagicAnimate") + if not 
os.path.exists(git_lfs_path): + os.system(f"git clone https://huggingface.co/zcxu-eric/MagicAnimate {git_lfs_path}") + else: + print(f"MagicAnimate already exists at {git_lfs_path}") + + animator = MagicAnimate() -demo.launch(share=True) \ No newline at end of file + demo.launch(share=True) diff --git a/demo/models.py b/demo/paths.py similarity index 82% rename from demo/models.py rename to demo/paths.py index 75e2b5e4..fd0e8ff0 100644 --- a/demo/models.py +++ b/demo/paths.py @@ -17,3 +17,6 @@ config_path = os.path.join(script_path, "configs", "prompts", "animation.yaml") inference_config_path = os.path.join(script_path, "configs", "inference", "inference.yaml") +source_images_path = os.path.join(script_path, "inputs", "applications", "source_image") +motion_sequences_path = os.path.join(script_path, "inputs", "applications", "driving", "densepose") + From a0e0efa6b46218d35824ebec80771e17b077d2d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 14:04:12 -0600 Subject: [PATCH 6/8] Really ignore outputs --- .gitignore | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 71ff561b..8e6c7218 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,6 @@ pretrained_models ./*.png ./*.mp4 demo/tmp -demo/outputs -.idea/* \ No newline at end of file +demo/outputs/* +.idea/* +outputs/* \ No newline at end of file From c30079da4f643ce9d3869a38c3b86c41a0ebf677 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 14:04:50 -0600 Subject: [PATCH 7/8] Improve pipeline loading --- demo/animate.py | 101 +++++++++++++++++++++++++++++------------------- 1 file changed, 62 insertions(+), 39 deletions(-) diff --git a/demo/animate.py b/demo/animate.py index db3123e0..6a11504b 100644 --- a/demo/animate.py +++ b/demo/animate.py @@ -18,12 +18,13 @@ from PIL import Image from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline +from diffusers.models.attention_processor import AttnProcessor2_0 from einops import rearrange from omegaconf import OmegaConf from transformers import CLIPTextModel, CLIPTokenizer from demo.paths import config_path, inference_config_path, pretrained_encoder_path, pretrained_controlnet_path, \ - motion_module + motion_module, pretrained_model_path, pretrained_vae_path, output_path from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl @@ -33,6 +34,15 @@ from magicanimate.utils.videoreader import VideoReader +def xformerify(obj): + try: + import xformers + obj.enable_xformers_memory_efficient_attention + + except ImportError: + obj.set_attn_processor(AttnProcessor2_0()) + + class MagicAnimate: def __init__(self) -> None: config = OmegaConf.load(config_path) @@ -47,16 +57,8 @@ def __init__(self) -> None: self.inference_config = inference_config ### Load controlnet and appearance encoder - self.appearance_encoder = AppearanceEncoderModel.from_pretrained(pretrained_encoder_path, - subfolder="appearance_encoder").cuda() - - self.controlnet = ControlNetModel.from_pretrained(pretrained_controlnet_path) - self.controlnet.to(torch.float16) - self.appearance_encoder.to(torch.float16) - self.appearance_encoder.enable_xformers_memory_efficient_attention() - self.controlnet.enable_xformers_memory_efficient_attention() - - + self.controlnet = None + self.appearance_encoder = None self.pipeline = None 
self.reference_control_writer = None self.reference_control_reader = None @@ -64,8 +66,8 @@ def __init__(self) -> None: print("Initialization Done!") - def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, checkpoint=None): - self.load_pipeline(checkpoint) + def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, half_precision=False, checkpoint=None): + self.load_pipeline(half_precision, checkpoint) prompt = n_prompt = "" random_seed = int(random_seed) step = int(step) @@ -130,32 +132,50 @@ def __call__(self, source_image, motion_sequence, random_seed, step, guidance_sc samples_per_video = torch.cat(samples_per_video) time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - savedir = f"demo/outputs" - animation_path = f"{savedir}/{time_str}.mp4" + savedir = output_path + animation_path = os.path.join(savedir, f"{time_str}.mp4") os.makedirs(savedir, exist_ok=True) save_videos_grid(samples_per_video, animation_path) return animation_path - def load_pipeline(self, model_path=None): + def load_pipeline(self, half_precision=False, model_path=None): if self.pipeline is not None: del self.pipeline + if self.appearance_encoder is not None: + del self.appearance_encoder + + if self.controlnet is not None: + del self.controlnet + torch.cuda.empty_cache() + + self.appearance_encoder = AppearanceEncoderModel.from_pretrained(pretrained_encoder_path, + subfolder="appearance_encoder").cuda() + + self.controlnet = ControlNetModel.from_pretrained(pretrained_controlnet_path) + if half_precision: + self.controlnet.to(torch.float16) + self.appearance_encoder.to(torch.float16) + xformerify(self.controlnet) + xformerify(self.appearance_encoder) + config = self.config inference_config = self.inference_config vae = None print(f"Loading pipeline from {model_path}") if not model_path: - model_path = config.pretrained_model_path - unet_path = config.pretrained_unet_path if config.pretrained_unet_path else model_path + model_path = pretrained_model_path + unet_path = model_path else: unet_path = model_path if "safetensors" in model_path or "ckpt" in model_path: temp_pipeline = StableDiffusionPipeline.from_single_file(model_path) tokenizer = temp_pipeline.tokenizer text_encoder = temp_pipeline.text_encoder - unet = temp_pipeline.unet + unet_2d = temp_pipeline.unet + unet = UNet3DConditionModel.from_2d_unet(unet_2d, inference_config.unet_additional_kwargs) try: vae = temp_pipeline.vae except: @@ -163,20 +183,15 @@ def load_pipeline(self, model_path=None): else: tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder") - if config.pretrained_unet_path: - unet = UNet3DConditionModel.from_pretrained_2d(unet_path, - unet_additional_kwargs=OmegaConf.to_container( - inference_config.unet_additional_kwargs)) - else: - unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet", - unet_additional_kwargs=OmegaConf.to_container( - inference_config.unet_additional_kwargs)) + unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs)) if vae is None: - if config.pretrained_vae_path is not None: + if os.path.exists(pretrained_vae_path): vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) else: - vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + vae = 
AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', @@ -184,6 +199,20 @@ def load_pipeline(self, model_path=None): self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) + if half_precision: + vae.to(torch.float16) + unet.to(torch.float16) + text_encoder.to(torch.float16) + + xformerify(unet) + xformerify(vae) + + self.pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=self.controlnet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + # NOTE: UniPCMultistepScheduler + ).to("cuda") + motion_module_state_dict = torch.load(motion_module, map_location="cpu") if "global_step" in motion_module_state_dict: self.func_args.update( {"global_step": motion_module_state_dict["global_step"]}) @@ -211,16 +240,10 @@ def load_pipeline(self, model_path=None): _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] + missing, unexpected = unet.load_state_dict(_tmp_, strict=False) + assert len(unexpected) == 0 del _tmp_ del motion_module_state_dict - vae.to(torch.float16) - unet.to(torch.float16) - text_encoder.to(torch.float16) - unet.enable_xformers_memory_efficient_attention() - - self.pipeline = AnimationPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=self.controlnet, - scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), - # NOTE: UniPCMultistepScheduler - ).to("cuda") + self.pipeline.to("cuda") + self.L = config.L From 225f9205e049d5b1079b5a82ac65d4a2d9387b29 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 14 Dec 2023 14:46:21 -0600 Subject: [PATCH 8/8] Add optional prompt, various cleanpu and fixes --- demo/animate.py | 4 +- demo/gradio_animate.py | 30 ++-- demo/paths.py | 1 + magicanimate/models/unet.py | 32 ++++ magicanimate/models/unet_controlnet.py | 179 ++++++++++++------- magicanimate/pipelines/pipeline_animation.py | 2 +- 6 files changed, 163 insertions(+), 85 deletions(-) diff --git a/demo/animate.py b/demo/animate.py index 6a11504b..001132df 100644 --- a/demo/animate.py +++ b/demo/animate.py @@ -66,9 +66,9 @@ def __init__(self) -> None: print("Initialization Done!") - def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, half_precision=False, checkpoint=None): + def __call__(self, prompt, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, half_precision=False, checkpoint=None): self.load_pipeline(half_precision, checkpoint) - prompt = n_prompt = "" + n_prompt = "" random_seed = int(random_seed) step = int(step) guidance_scale = float(guidance_scale) diff --git a/demo/gradio_animate.py b/demo/gradio_animate.py index c589c483..52013c27 100644 --- a/demo/gradio_animate.py +++ b/demo/gradio_animate.py @@ -36,8 +36,9 @@ def list_checkpoints(): # source_image, motion_sequence, random_seed, step, guidance_scale, size=512, checkpoint=None -def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale, size, checkpoint=None): - return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale, size, checkpoint) +def animate(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision, checkpoint=None): + 
return animator(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision, + checkpoint) with gr.Blocks() as demo: @@ -60,6 +61,9 @@ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale, animation = gr.Video(format="mp4", label="Animation Results", autoplay=True) with gr.Row(): checkpoint = gr.Dropdown(label="Checkpoint", choices=list_checkpoints()) + half_precision = gr.Checkbox(label="Half precision", default=False) + with gr.Row(): + prompt = gr.Textbox(label="Prompt", placeholder="Type an optional prompt here", lines=2) with gr.Row(): reference_image = gr.Image(label="Reference Image") motion_sequence = gr.Video(format="mp4", label="Motion Sequence") @@ -95,25 +99,25 @@ def read_image(image, size=512): reference_image ) # when the `submit` button is clicked - #source_image, motion_sequence, random_seed, step, guidance_scale, size = 512, checkpoint = None + # source_image, motion_sequence, random_seed, step, guidance_scale, size = 512, checkpoint = None submit.click( animate, - [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, checkpoint], + [prompt, reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, half_precision, + checkpoint], animation ) - # source_images_path = os.path.join(script_path, "inputs", "applications", "source_image") - # motion_sequences_path = os.path.join(script_path, "inputs", "applications", "driving", "densepose") - #Examples + # Examples gr.Markdown("## Examples") + gr.Examples( examples=[ - [f"{source_images_path}/monalisa.png", f"{motion_sequences_path}/running.mp4"], - [f"{source_images_path}/demo4.png", f"{motion_sequences_path}/demo4.mp4"], - [f"{source_images_path}/dalle2.jpeg", f"{motion_sequences_path}/running2.mp4"], - [f"{source_images_path}/dalle8.jpeg", f"{motion_sequences_path}/dancing2.mp4"], - [f"{source_images_path}/multi1_source.png", - f"{motion_sequences_path}/multi_dancing.mp4"], + [os.path.join(source_images_path, "monalisa.png"), os.path.join(motion_sequences_path, "running.mp4")], + [os.path.join(source_images_path, "demo4.png"), os.path.join(motion_sequences_path, "demo4.mp4")], + [os.path.join(source_images_path, "dalle2.jpeg"), os.path.join(motion_sequences_path, "running2.mp4")], + [os.path.join(source_images_path, "dalle8.jpeg"), os.path.join(motion_sequences_path, "dancing2.mp4")], + [os.path.join(source_images_path, "multi1_source.png"), + os.path.join(motion_sequences_path, "multi_dancing.mp4")], ], inputs=[reference_image, motion_sequence], outputs=animation, diff --git a/demo/paths.py b/demo/paths.py index fd0e8ff0..2405d9cc 100644 --- a/demo/paths.py +++ b/demo/paths.py @@ -1,6 +1,7 @@ import os script_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")) +output_path = os.path.join(script_path, "outputs") models_path = os.path.join(script_path, "pretrained_models") magic_models_path = os.path.join(models_path, "MagicAnimate") diff --git a/magicanimate/models/unet.py b/magicanimate/models/unet.py index 09e5e11f..b3f9dfc0 100644 --- a/magicanimate/models/unet.py +++ b/magicanimate/models/unet.py @@ -506,3 +506,35 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") return model + + @classmethod + def from_2d_unet(cls, unet2d_model, unet_additional_kwargs=None): + # Extract configuration from 2D UNet model + config_2d = unet2d_model.config + config_3d = config_2d.copy() # 
Convert to a dictionary if necessary + + # Update configuration for 3D UNet specifics + config_3d["_class_name"] = cls.__name__ + config_3d["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config_3d["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + # Initialize a new 3D UNet model with the updated configuration + model_3d = cls.from_config(config_3d, **unet_additional_kwargs) + + # Load state dict from 2D model to 3D model + # Note: This might require additional handling if the architectures differ significantly + state_dict_2d = unet2d_model.state_dict() + model_3d.load_state_dict(state_dict_2d, strict=False) + + return model_3d + diff --git a/magicanimate/models/unet_controlnet.py b/magicanimate/models/unet_controlnet.py index 0ccd9cad..ef3501e3 100644 --- a/magicanimate/models/unet_controlnet.py +++ b/magicanimate/models/unet_controlnet.py @@ -42,7 +42,6 @@ ) from .resnet import InflatedConv3d - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -56,52 +55,52 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): @register_to_config def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "DownBlock3D", - ), - mid_block_type: str = "UNetMidBlock3DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" - ), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - act_fn: str = "silu", - norm_num_groups: int = 32, - norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, - attention_head_dim: Union[int, Tuple[int]] = 8, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - - # Additional - use_motion_module = False, - motion_module_resolutions = ( 1,2,4,8 ), - motion_module_mid_block = False, - motion_module_decoder_only = False, - motion_module_type = None, - motion_module_kwargs = {}, - unet_use_cross_frame_attention = None, - unet_use_temporal_attention = None, + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + 
use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + + # Additional + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs={}, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, ): super().__init__() @@ -166,8 +165,9 @@ def __init__( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, - - use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + + use_motion_module=use_motion_module and (res in motion_module_resolutions) and ( + not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) @@ -191,14 +191,14 @@ def __init__( unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, - + use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") - + # count how many layers upsample the videos self.num_upsamplers = 0 @@ -326,16 +326,16 @@ def _set_gradient_checkpointing(self, module, value=False): module.gradient_checkpointing = value def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - # for controlnet - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - return_dict: bool = True, + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + # for controlnet + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" Args: @@ -354,7 +354,7 @@ def forward( # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. 
- default_overall_up_factor = 2**self.num_upsamplers + default_overall_up_factor = 2 ** self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False @@ -423,7 +423,8 @@ def forward( attention_mask=attention_mask, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, + encoder_hidden_states=encoder_hidden_states) down_block_res_samples += res_samples @@ -431,7 +432,7 @@ def forward( new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals + down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) @@ -450,7 +451,7 @@ def forward( for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + res_samples = down_block_res_samples[-len(upsample_block.resnets):] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the @@ -469,7 +470,8 @@ def forward( ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, + encoder_hidden_states=encoder_hidden_states, ) # post-process @@ -506,7 +508,7 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition "CrossAttnUpBlock3D", "CrossAttnUpBlock3D" ] - # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + config["mid_block_type"] = "UNetMidBlock3DCrossAttn" from diffusers.utils import WEIGHTS_NAME model = cls.from_config(config, **unet_additional_kwargs) @@ -518,8 +520,47 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") - + params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") - + return model + + @classmethod + def from_2d_unet(cls, unet2d_model, unet_additional_kwargs=None): + # Extract the configuration from the provided 2D UNet model + config_2d = unet2d_model.config + + # Adapt the configuration for the 3D UNet model + config_3d = config_2d.copy() + config_3d["_class_name"] = cls.__name__ + config_3d["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config_3d["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + config_3d["mid_block_type"] = "UNetMidBlock3DCrossAttn" + + # Initialize the 3D UNet model with the adapted configuration + model_3d = cls.from_config(config_3d, **unet_additional_kwargs) + + # Load the state dict (weights and biases) from the 2D model to the 3D model + # This operation might require handling of key mismatches due to different architectures + state_dict_2d = 
unet2d_model.state_dict() + # The strict flag is set to False to allow for differences in the model architectures + missing_keys, unexpected_keys = model_3d.load_state_dict(state_dict_2d, strict=False) + + # Optionally, handle or log missing and unexpected keys + if missing_keys or unexpected_keys: + logger.info(f"Missing keys: {missing_keys}") + logger.info(f"Unexpected keys: {unexpected_keys}") + + return model_3d + diff --git a/magicanimate/pipelines/pipeline_animation.py b/magicanimate/pipelines/pipeline_animation.py index 972f70fe..26d55135 100644 --- a/magicanimate/pipelines/pipeline_animation.py +++ b/magicanimate/pipelines/pipeline_animation.py @@ -621,7 +621,7 @@ def __call__( if init_latents is not None: latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length) else: - num_channels_latents = self.unet.in_channels + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents,
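
Note: the sketch below is not part of the patch series; it only illustrates how the demo entry point looks once PATCH 8/8 is applied. After these patches, MagicAnimate builds its pipeline lazily on every call via load_pipeline() and accepts an optional prompt, a half-precision flag, and an optional .ckpt/.safetensors checkpoint. A minimal, assumption-laden example of driving it directly from Python (rather than through the Gradio UI), assuming the repo is the working directory, the pretrained_models/ layout from demo/paths.py is already populated, and the bundled example inputs are present:

    # Hypothetical usage sketch, not shipped by the patches above.
    import numpy as np
    from PIL import Image

    from demo.animate import MagicAnimate

    animator = MagicAnimate()  # loads configs/prompts/animation.yaml; pipeline is built per call

    # Reference image as an HxWxC uint8 array, as the Gradio Image component would supply.
    source = np.array(Image.open("inputs/applications/source_image/monalisa.png").convert("RGB"))

    result_path = animator(
        prompt="",                                                        # optional prompt added in PATCH 8/8
        source_image=source,
        motion_sequence="inputs/applications/driving/densepose/running.mp4",
        random_seed=1,
        step=25,
        guidance_scale=7.5,
        size=512,
        half_precision=True,                                              # casts UNet/VAE/text encoder to fp16 in load_pipeline()
        checkpoint=None,                                                  # or a .ckpt/.safetensors path under pretrained_models/checkpoints
    )
    print(f"Animation written to {result_path}")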