diff --git a/.gitignore b/.gitignore
index cdb0f4cc..8e6c7218 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,6 @@ pretrained_models
./*.png
./*.mp4
demo/tmp
-demo/outputs
\ No newline at end of file
+demo/outputs/*
+.idea/*
+outputs/*
\ No newline at end of file
diff --git a/demo/animate.py b/demo/animate.py
index b71f1940..001132df 100644
--- a/demo/animate.py
+++ b/demo/animate.py
@@ -8,91 +8,216 @@
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
-import argparse
-import argparse
import datetime
import inspect
import os
-import numpy as np
-from PIL import Image
-from omegaconf import OmegaConf
from collections import OrderedDict
+import numpy as np
import torch
-
-from diffusers import AutoencoderKL, DDIMScheduler, UniPCMultistepScheduler
-
-from tqdm import tqdm
+from PIL import Image
+from accelerate.utils import set_seed
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline
+from diffusers.models.attention_processor import AttnProcessor2_0
+from einops import rearrange
+from omegaconf import OmegaConf
from transformers import CLIPTextModel, CLIPTokenizer
-from magicanimate.models.unet_controlnet import UNet3DConditionModel
-from magicanimate.models.controlnet import ControlNetModel
+from demo.paths import config_path, inference_config_path, pretrained_encoder_path, pretrained_controlnet_path, \
+ motion_module, pretrained_model_path, pretrained_vae_path, output_path
from magicanimate.models.appearance_encoder import AppearanceEncoderModel
+from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
+from magicanimate.models.unet_controlnet import UNet3DConditionModel
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.util import save_videos_grid
-from accelerate.utils import set_seed
-
from magicanimate.utils.videoreader import VideoReader
-from einops import rearrange, repeat
-import csv, pdb, glob
-from safetensors import safe_open
-import math
-from pathlib import Path
+def xformerify(obj):
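+    # prefer xformers memory-efficient attention when available, otherwise fall back to PyTorch 2.0 scaled dot-product attention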
+ try:
+ import xformers
+        obj.enable_xformers_memory_efficient_attention()
+
+ except ImportError:
+ obj.set_attn_processor(AttnProcessor2_0())
+
+
+class MagicAnimate:
+ def __init__(self) -> None:
+ config = OmegaConf.load(config_path)
-class MagicAnimate():
- def __init__(self, config="configs/prompts/animation.yaml") -> None:
print("Initializing MagicAnimate Pipeline...")
*_, func_args = inspect.getargvalues(inspect.currentframe())
- func_args = dict(func_args)
-
- config = OmegaConf.load(config)
-
- inference_config = OmegaConf.load(config.inference_config)
-
- motion_module = config.motion_module
-
- ### >>> create animation pipeline >>> ###
- tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
- if config.pretrained_unet_path:
- unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+ self.func_args = dict(func_args)
+
+ self.config = config
+
+ inference_config = OmegaConf.load(inference_config_path)
+ self.inference_config = inference_config
+
+        ### Controlnet, appearance encoder and pipeline are created lazily in load_pipeline()
+ self.controlnet = None
+ self.appearance_encoder = None
+ self.pipeline = None
+ self.reference_control_writer = None
+ self.reference_control_reader = None
+ self.L = config.L
+
+ print("Initialization Done!")
+
+ def __call__(self, prompt, source_image, motion_sequence, random_seed, step, guidance_scale, size=512, half_precision=False, checkpoint=None):
+ self.load_pipeline(half_precision, checkpoint)
+ n_prompt = ""
+ random_seed = int(random_seed)
+ step = int(step)
+ guidance_scale = float(guidance_scale)
+ samples_per_video = []
+ # manually set random seed for reproduction
+ if random_seed != -1:
+ torch.manual_seed(random_seed)
+ set_seed(random_seed)
+ else:
+ torch.seed()
+
+ if motion_sequence.endswith('.mp4'):
+ control = VideoReader(motion_sequence).read()
+ if control[0].shape[0] != size:
+ control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
+ control = np.array(control)
+ else:
+ control = np.load(motion_sequence)
+ if control.shape[1] != size:
+ control = np.array([np.array(Image.fromarray(c).resize((size, size))) for c in control])
+ control = np.array(control)
+
+ if source_image.shape[0] != size:
+ source_image = np.array(Image.fromarray(source_image).resize((size, size)))
+ H, W, C = source_image.shape
+
+ init_latents = None
+ original_length = control.shape[0]
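+        # pad the control frames to a multiple of the temporal window L; the padding is trimmed back to original_length below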
+ if control.shape[0] % self.L > 0:
+ control = np.pad(control, ((0, self.L - control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
+ generator = torch.Generator(device=torch.device("cuda:0"))
+ generator.manual_seed(torch.initial_seed())
+ sample = self.pipeline(
+ prompt,
+ negative_prompt=n_prompt,
+ num_inference_steps=step,
+ guidance_scale=guidance_scale,
+ width=W,
+ height=H,
+ video_length=len(control),
+ controlnet_condition=control,
+ init_latents=init_latents,
+ generator=generator,
+ appearance_encoder=self.appearance_encoder,
+ reference_control_writer=self.reference_control_writer,
+ reference_control_reader=self.reference_control_reader,
+ source_image=source_image,
+ ).videos
+
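+        # stack source image, control sequence and generated frames into one comparison grid for saving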
+ source_images = np.array([source_image] * original_length)
+ source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
+ samples_per_video.append(source_images)
+
+ control = control / 255.0
+ control = rearrange(control, "t h w c -> 1 c t h w")
+ control = torch.from_numpy(control)
+ samples_per_video.append(control[:, :, :original_length])
+
+ samples_per_video.append(sample[:, :, :original_length])
+
+ samples_per_video = torch.cat(samples_per_video)
+
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ savedir = output_path
+ animation_path = os.path.join(savedir, f"{time_str}.mp4")
+
+ os.makedirs(savedir, exist_ok=True)
+ save_videos_grid(samples_per_video, animation_path)
+
+ return animation_path
+
+ def load_pipeline(self, half_precision=False, model_path=None):
+ if self.pipeline is not None:
+ del self.pipeline
+
+ if self.appearance_encoder is not None:
+ del self.appearance_encoder
+
+ if self.controlnet is not None:
+ del self.controlnet
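+        # free VRAM held by any previously loaded models before loading the new checkpoint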
+ torch.cuda.empty_cache()
+
+ self.appearance_encoder = AppearanceEncoderModel.from_pretrained(pretrained_encoder_path,
+ subfolder="appearance_encoder").cuda()
+
+ self.controlnet = ControlNetModel.from_pretrained(pretrained_controlnet_path)
+ if half_precision:
+ self.controlnet.to(torch.float16)
+ self.appearance_encoder.to(torch.float16)
+ xformerify(self.controlnet)
+ xformerify(self.appearance_encoder)
+
+ config = self.config
+ inference_config = self.inference_config
+ vae = None
+ print(f"Loading pipeline from {model_path}")
+ if not model_path:
+ model_path = pretrained_model_path
+ unet_path = model_path
else:
- unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
- self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
- self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
- self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
- if config.pretrained_vae_path is not None:
- vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
+ unet_path = model_path
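+            # single-file checkpoints (.safetensors / .ckpt) are loaded with StableDiffusionPipeline.from_single_file and their 2D UNet is inflated to 3D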
+ if "safetensors" in model_path or "ckpt" in model_path:
+ temp_pipeline = StableDiffusionPipeline.from_single_file(model_path)
+ tokenizer = temp_pipeline.tokenizer
+ text_encoder = temp_pipeline.text_encoder
+ unet_2d = temp_pipeline.unet
+ unet = UNet3DConditionModel.from_2d_unet(unet_2d, inference_config.unet_additional_kwargs)
+ try:
+ vae = temp_pipeline.vae
+            except AttributeError:
+ print("No VAE found in ckpt, using default VAE")
else:
- vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
+ tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
+ unet = UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet",
+ unet_additional_kwargs=OmegaConf.to_container(
+ inference_config.unet_additional_kwargs))
+
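+        # VAE fallback order: VAE bundled in the single-file checkpoint, then a local pretrained_vae_path, then the base model's "vae" subfolder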
+ if vae is None:
+ if os.path.exists(pretrained_vae_path):
+                vae = AutoencoderKL.from_pretrained(pretrained_vae_path)
+ else:
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+
+ self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder,
+ do_classifier_free_guidance=True, mode='write',
+ fusion_blocks=config.fusion_blocks)
+ self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read',
+ fusion_blocks=config.fusion_blocks)
- ### Load controlnet
- controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)
+ if half_precision:
+ vae.to(torch.float16)
+ unet.to(torch.float16)
+ text_encoder.to(torch.float16)
- vae.to(torch.float16)
- unet.to(torch.float16)
- text_encoder.to(torch.float16)
- controlnet.to(torch.float16)
- self.appearance_encoder.to(torch.float16)
-
- unet.enable_xformers_memory_efficient_attention()
- self.appearance_encoder.enable_xformers_memory_efficient_attention()
- controlnet.enable_xformers_memory_efficient_attention()
+ xformerify(unet)
+ xformerify(vae)
self.pipeline = AnimationPipeline(
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=self.controlnet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
# NOTE: UniPCMultistepScheduler
).to("cuda")
- # 1. unet ckpt
- # 1.1 motion module
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
- if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
- motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
+ if "global_step" in motion_module_state_dict: self.func_args.update(
+ {"global_step": motion_module_state_dict["global_step"]})
+ motion_module_state_dict = motion_module_state_dict[
+ 'state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
try:
# extra steps for self-trained models
state_dict = OrderedDict()
@@ -122,74 +247,3 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None:
self.pipeline.to("cuda")
self.L = config.L
-
- print("Initialization Done!")
-
- def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
- prompt = n_prompt = ""
- random_seed = int(random_seed)
- step = int(step)
- guidance_scale = float(guidance_scale)
- samples_per_video = []
- # manually set random seed for reproduction
- if random_seed != -1:
- torch.manual_seed(random_seed)
- set_seed(random_seed)
- else:
- torch.seed()
-
- if motion_sequence.endswith('.mp4'):
- control = VideoReader(motion_sequence).read()
- if control[0].shape[0] != size:
- control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
- control = np.array(control)
-
- if source_image.shape[0] != size:
- source_image = np.array(Image.fromarray(source_image).resize((size, size)))
- H, W, C = source_image.shape
-
- init_latents = None
- original_length = control.shape[0]
- if control.shape[0] % self.L > 0:
- control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
- generator = torch.Generator(device=torch.device("cuda:0"))
- generator.manual_seed(torch.initial_seed())
- sample = self.pipeline(
- prompt,
- negative_prompt = n_prompt,
- num_inference_steps = step,
- guidance_scale = guidance_scale,
- width = W,
- height = H,
- video_length = len(control),
- controlnet_condition = control,
- init_latents = init_latents,
- generator = generator,
- appearance_encoder = self.appearance_encoder,
- reference_control_writer = self.reference_control_writer,
- reference_control_reader = self.reference_control_reader,
- source_image = source_image,
- ).videos
-
- source_images = np.array([source_image] * original_length)
- source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
- samples_per_video.append(source_images)
-
- control = control / 255.0
- control = rearrange(control, "t h w c -> 1 c t h w")
- control = torch.from_numpy(control)
- samples_per_video.append(control[:, :, :original_length])
-
- samples_per_video.append(sample[:, :, :original_length])
-
- samples_per_video = torch.cat(samples_per_video)
-
- time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
- savedir = f"demo/outputs"
- animation_path = f"{savedir}/{time_str}.mp4"
-
- os.makedirs(savedir, exist_ok=True)
- save_videos_grid(samples_per_video, animation_path)
-
- return animation_path
-
\ No newline at end of file
diff --git a/demo/gradio_animate.py b/demo/gradio_animate.py
index 9a932a28..52013c27 100644
--- a/demo/gradio_animate.py
+++ b/demo/gradio_animate.py
@@ -8,21 +8,40 @@
# disclosure or distribution of this material and related documentation
# without an express license agreement from ByteDance or
# its affiliates is strictly prohibited.
-import argparse
+import os
+
+import gradio as gr
import imageio
import numpy as np
-import gradio as gr
from PIL import Image
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
from demo.animate import MagicAnimate
+from demo.paths import config_path, script_path, models_path, magic_models_path, source_images_path, \
+ motion_sequences_path
-animator = MagicAnimate()
+animator = None
-def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale):
- return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale)
-with gr.Blocks() as demo:
+def list_checkpoints():
+ checkpoint_dir = os.path.join(models_path, "checkpoints")
+ # Recursively find all .ckpt and .safetensors files
+ checkpoints = [""]
+ for root, dirs, files in os.walk(checkpoint_dir):
+ for file in files:
+ if file.endswith(".ckpt") or file.endswith(".safetensors"):
+ checkpoints.append(os.path.join(root, file))
+ return checkpoints
+
+
+# source_image, motion_sequence, random_seed, step, guidance_scale, size=512, checkpoint=None
+def animate(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision, checkpoint=None):
+ return animator(prompt, reference_image, motion_sequence_state, seed, steps, guidance_scale, size, half_precision,
+ checkpoint)
+
+with gr.Blocks() as demo:
gr.HTML(
"""
@@ -40,25 +59,33 @@ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale)
""")
animation = gr.Video(format="mp4", label="Animation Results", autoplay=True)
-
with gr.Row():
- reference_image = gr.Image(label="Reference Image")
- motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
-
+ checkpoint = gr.Dropdown(label="Checkpoint", choices=list_checkpoints())
+        half_precision = gr.Checkbox(label="Half precision", value=False)
+ with gr.Row():
+ prompt = gr.Textbox(label="Prompt", placeholder="Type an optional prompt here", lines=2)
+ with gr.Row():
+ reference_image = gr.Image(label="Reference Image")
+ motion_sequence = gr.Video(format="mp4", label="Motion Sequence")
+
with gr.Column():
- random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1")
- sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25")
- guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5")
- submit = gr.Button("Animate")
+        size = gr.Slider(label="Size", value=512, minimum=256, maximum=1024, step=256, info="default: 512", visible=False)
+        random_seed = gr.Slider(label="Random seed", value=1, minimum=-1, maximum=2147483647, step=1, info="default: -1 (random)")
+ sampling_steps = gr.Slider(label="Sampling steps", value=25, info="default: 25")
+ guidance_scale = gr.Slider(label="Guidance scale", value=7.5, info="default: 7.5", step=0.1)
+ submit = gr.Button("Animate")
+
def read_video(video):
reader = imageio.get_reader(video)
fps = reader.get_meta_data()['fps']
return video
-
+
+
def read_image(image, size=512):
return np.array(Image.fromarray(image).resize((size, size)))
-
+
+
# when user uploads a new video
motion_sequence.upload(
read_video,
@@ -72,25 +99,45 @@ def read_image(image, size=512):
reference_image
)
# when the `submit` button is clicked
+ # source_image, motion_sequence, random_seed, step, guidance_scale, size = 512, checkpoint = None
submit.click(
animate,
- [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale],
+ [prompt, reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, size, half_precision,
+ checkpoint],
animation
)
# Examples
gr.Markdown("## Examples")
+
gr.Examples(
examples=[
- ["inputs/applications/source_image/monalisa.png", "inputs/applications/driving/densepose/running.mp4"],
- ["inputs/applications/source_image/demo4.png", "inputs/applications/driving/densepose/demo4.mp4"],
- ["inputs/applications/source_image/dalle2.jpeg", "inputs/applications/driving/densepose/running2.mp4"],
- ["inputs/applications/source_image/dalle8.jpeg", "inputs/applications/driving/densepose/dancing2.mp4"],
- ["inputs/applications/source_image/multi1_source.png", "inputs/applications/driving/densepose/multi_dancing.mp4"],
+ [os.path.join(source_images_path, "monalisa.png"), os.path.join(motion_sequences_path, "running.mp4")],
+ [os.path.join(source_images_path, "demo4.png"), os.path.join(motion_sequences_path, "demo4.mp4")],
+ [os.path.join(source_images_path, "dalle2.jpeg"), os.path.join(motion_sequences_path, "running2.mp4")],
+ [os.path.join(source_images_path, "dalle8.jpeg"), os.path.join(motion_sequences_path, "dancing2.mp4")],
+ [os.path.join(source_images_path, "multi1_source.png"),
+ os.path.join(motion_sequences_path, "multi_dancing.mp4")],
],
inputs=[reference_image, motion_sequence],
outputs=animation,
)
+if __name__ == '__main__':
+ if not os.path.exists(models_path):
+ os.mkdir(models_path)
+
+ if not os.path.exists(os.path.join(models_path, "checkpoints")):
+ os.mkdir(os.path.join(models_path, "checkpoints"))
+
+ if not os.path.exists(magic_models_path):
+        # clone https://huggingface.co/zcxu-eric/MagicAnimate with git (LFS) rather than downloading files individually via hf_hub_download
+ git_lfs_path = os.path.join(models_path, "MagicAnimate")
+ if not os.path.exists(git_lfs_path):
+ os.system(f"git clone https://huggingface.co/zcxu-eric/MagicAnimate {git_lfs_path}")
+ else:
+ print(f"MagicAnimate already exists at {git_lfs_path}")
+
+ animator = MagicAnimate()
-demo.launch(share=True)
\ No newline at end of file
+ demo.launch(share=True)
diff --git a/demo/paths.py b/demo/paths.py
new file mode 100644
index 00000000..2405d9cc
--- /dev/null
+++ b/demo/paths.py
@@ -0,0 +1,23 @@
+import os
+
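+# repository root (the parent of this demo/ directory)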
+script_path = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), ".."))
+output_path = os.path.join(script_path, "outputs")
+models_path = os.path.join(script_path, "pretrained_models")
+magic_models_path = os.path.join(models_path, "MagicAnimate")
+
+pretrained_model_path = os.path.join(models_path, "stable-diffusion-v1-5")
+pretrained_vae_path = os.path.join(models_path, "pretrained_vae")
+
+pretrained_controlnet_path = os.path.join(magic_models_path, "densepose_controlnet")
+pretrained_encoder_path = os.path.join(magic_models_path, "appearance_encoder")
+pretrained_motion_module_path = os.path.join(magic_models_path, "temporal_attention")
+motion_module = os.path.join(pretrained_motion_module_path, "temporal_attention.ckpt")
+
+pretrained_unet_path = ""
+
+config_path = os.path.join(script_path, "configs", "prompts", "animation.yaml")
+inference_config_path = os.path.join(script_path, "configs", "inference", "inference.yaml")
+
+source_images_path = os.path.join(script_path, "inputs", "applications", "source_image")
+motion_sequences_path = os.path.join(script_path, "inputs", "applications", "driving", "densepose")
+
diff --git a/environment.yaml b/environment.yaml
index 00bcc3f4..c8dec364 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -2,6 +2,7 @@ name: manimate
channels:
- conda-forge
- defaults
+ - nvidia
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
diff --git a/magicanimate/models/unet.py b/magicanimate/models/unet.py
index 09e5e11f..b3f9dfc0 100644
--- a/magicanimate/models/unet.py
+++ b/magicanimate/models/unet.py
@@ -506,3 +506,35 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
return model
+
+ @classmethod
+ def from_2d_unet(cls, unet2d_model, unet_additional_kwargs=None):
+ # Extract configuration from 2D UNet model
+ config_2d = unet2d_model.config
+        config_3d = dict(config_2d)  # plain dict copy so the 3D-specific keys below can be set
+
+ # Update configuration for 3D UNet specifics
+ config_3d["_class_name"] = cls.__name__
+ config_3d["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config_3d["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ # Initialize a new 3D UNet model with the updated configuration
+        model_3d = cls.from_config(config_3d, **(unet_additional_kwargs or {}))
+
+ # Load state dict from 2D model to 3D model
+ # Note: This might require additional handling if the architectures differ significantly
+ state_dict_2d = unet2d_model.state_dict()
+ model_3d.load_state_dict(state_dict_2d, strict=False)
+
+ return model_3d
+
diff --git a/magicanimate/models/unet_controlnet.py b/magicanimate/models/unet_controlnet.py
index 0ccd9cad..ef3501e3 100644
--- a/magicanimate/models/unet_controlnet.py
+++ b/magicanimate/models/unet_controlnet.py
@@ -42,7 +42,6 @@
)
from .resnet import InflatedConv3d
-
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -56,52 +55,52 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
- self,
- sample_size: Optional[int] = None,
- in_channels: int = 4,
- out_channels: int = 4,
- center_input_sample: bool = False,
- flip_sin_to_cos: bool = True,
- freq_shift: int = 0,
- down_block_types: Tuple[str] = (
- "CrossAttnDownBlock3D",
- "CrossAttnDownBlock3D",
- "CrossAttnDownBlock3D",
- "DownBlock3D",
- ),
- mid_block_type: str = "UNetMidBlock3DCrossAttn",
- up_block_types: Tuple[str] = (
- "UpBlock3D",
- "CrossAttnUpBlock3D",
- "CrossAttnUpBlock3D",
- "CrossAttnUpBlock3D"
- ),
- only_cross_attention: Union[bool, Tuple[bool]] = False,
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- layers_per_block: int = 2,
- downsample_padding: int = 1,
- mid_block_scale_factor: float = 1,
- act_fn: str = "silu",
- norm_num_groups: int = 32,
- norm_eps: float = 1e-5,
- cross_attention_dim: int = 1280,
- attention_head_dim: Union[int, Tuple[int]] = 8,
- dual_cross_attention: bool = False,
- use_linear_projection: bool = False,
- class_embed_type: Optional[str] = None,
- num_class_embeds: Optional[int] = None,
- upcast_attention: bool = False,
- resnet_time_scale_shift: str = "default",
-
- # Additional
- use_motion_module = False,
- motion_module_resolutions = ( 1,2,4,8 ),
- motion_module_mid_block = False,
- motion_module_decoder_only = False,
- motion_module_type = None,
- motion_module_kwargs = {},
- unet_use_cross_frame_attention = None,
- unet_use_temporal_attention = None,
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+
+ # Additional
+ use_motion_module=False,
+ motion_module_resolutions=(1, 2, 4, 8),
+ motion_module_mid_block=False,
+ motion_module_decoder_only=False,
+ motion_module_type=None,
+ motion_module_kwargs={},
+ unet_use_cross_frame_attention=None,
+ unet_use_temporal_attention=None,
):
super().__init__()
@@ -166,8 +165,9 @@ def __init__(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
-
- use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
+
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (
+ not motion_module_decoder_only),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
@@ -191,14 +191,14 @@ def __init__(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
-
+
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
-
+
# count how many layers upsample the videos
self.num_upsamplers = 0
@@ -326,16 +326,16 @@ def _set_gradient_checkpointing(self, module, value=False):
module.gradient_checkpointing = value
def forward(
- self,
- sample: torch.FloatTensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- class_labels: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- # for controlnet
- down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
- mid_block_additional_residual: Optional[torch.Tensor] = None,
- return_dict: bool = True,
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ # for controlnet
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
Args:
@@ -354,7 +354,7 @@ def forward(
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
- default_overall_up_factor = 2**self.num_upsamplers
+ default_overall_up_factor = 2 ** self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
@@ -423,7 +423,8 @@ def forward(
attention_mask=attention_mask,
)
else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb,
+ encoder_hidden_states=encoder_hidden_states)
down_block_res_samples += res_samples
@@ -431,7 +432,7 @@ def forward(
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
- down_block_res_samples, down_block_additional_residuals
+ down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
@@ -450,7 +451,7 @@ def forward(
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
@@ -469,7 +470,8 @@ def forward(
)
else:
sample = upsample_block(
- hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size,
+ encoder_hidden_states=encoder_hidden_states,
)
# post-process
@@ -506,7 +508,7 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D"
]
- # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+ config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **unet_additional_kwargs)
@@ -518,8 +520,47 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_addition
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
# print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
-
+
params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
-
+
return model
+
+ @classmethod
+ def from_2d_unet(cls, unet2d_model, unet_additional_kwargs=None):
+ # Extract the configuration from the provided 2D UNet model
+ config_2d = unet2d_model.config
+
+ # Adapt the configuration for the 3D UNet model
+        config_3d = dict(config_2d)  # plain dict copy so the keys below can be overridden
+ config_3d["_class_name"] = cls.__name__
+ config_3d["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config_3d["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+ config_3d["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+
+ # Initialize the 3D UNet model with the adapted configuration
+        model_3d = cls.from_config(config_3d, **(unet_additional_kwargs or {}))
+
+ # Load the state dict (weights and biases) from the 2D model to the 3D model
+ # This operation might require handling of key mismatches due to different architectures
+ state_dict_2d = unet2d_model.state_dict()
+ # The strict flag is set to False to allow for differences in the model architectures
+ missing_keys, unexpected_keys = model_3d.load_state_dict(state_dict_2d, strict=False)
+
+ # Optionally, handle or log missing and unexpected keys
+ if missing_keys or unexpected_keys:
+ logger.info(f"Missing keys: {missing_keys}")
+ logger.info(f"Unexpected keys: {unexpected_keys}")
+
+ return model_3d
+
diff --git a/magicanimate/pipelines/pipeline_animation.py b/magicanimate/pipelines/pipeline_animation.py
index 08a77bac..26d55135 100644
--- a/magicanimate/pipelines/pipeline_animation.py
+++ b/magicanimate/pipelines/pipeline_animation.py
@@ -33,6 +33,7 @@
import numpy as np
import torch
import torch.distributed as dist
+from diffusers import DiffusionPipeline
from tqdm import tqdm
from diffusers.utils import is_accelerate_available
from packaging import version
@@ -40,7 +41,6 @@
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
-from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
@@ -621,7 +621,7 @@ def __call__(
if init_latents is not None:
latents = rearrange(init_latents, "(b f) c h w -> b c f h w", f=video_length)
else:
- num_channels_latents = self.unet.in_channels
+ num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
diff --git a/requirements.txt b/requirements.txt
index 82f59c87..83a35267 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -16,7 +16,8 @@ click==8.1.7
cmake==3.27.2
contourpy==1.1.0
cycler==0.11.0
-datasets==2.14.4
+datasets
+diffusers==0.24.0
dill==0.3.7
einops==0.6.1
exceptiongroup==1.1.3
@@ -34,8 +35,9 @@ grpcio==1.57.0
h11==0.14.0
httpcore==0.17.3
httpx==0.24.1
-huggingface-hub==0.16.4
+huggingface-hub
idna==3.4
+imageio
importlib-metadata==6.8.0
importlib-resources==6.0.1
jinja2==3.1.2
@@ -53,17 +55,6 @@ multidict==6.0.4
multiprocess==0.70.15
networkx==3.1
numpy==1.24.4
-nvidia-cublas-cu11==11.10.3.66
-nvidia-cuda-cupti-cu11==11.7.101
-nvidia-cuda-nvrtc-cu11==11.7.99
-nvidia-cuda-runtime-cu11==11.7.99
-nvidia-cudnn-cu11==8.5.0.96
-nvidia-cufft-cu11==10.9.0.58
-nvidia-curand-cu11==10.2.10.91
-nvidia-cusolver-cu11==11.4.0.1
-nvidia-cusparse-cu11==11.7.4.91
-nvidia-nccl-cu11==2.14.3
-nvidia-nvtx-cu11==11.7.91
oauthlib==3.2.2
omegaconf==2.3.0
opencv-python==4.8.0.76
@@ -102,7 +93,6 @@ toolz==0.12.0
torchmetrics==1.1.0
tqdm==4.66.1
transformers==4.32.0
-triton==2.0.0
tzdata==2023.3
urllib3==1.26.16
uvicorn==0.23.2
@@ -124,4 +114,4 @@ ffmpeg-python
torch==2.0.1
torchvision==0.15.2
xformers==0.0.22
-diffusers==0.21.4
+--extra-index-url https://download.pytorch.org/whl/cu118