From 33ee810acb35bc44439eb291c6a78100fbe20fd8 Mon Sep 17 00:00:00 2001 From: johbau Date: Fri, 29 May 2026 09:24:58 +0200 Subject: [PATCH 1/5] inference: load Lance without the CPU RAM spike MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lets inference_lance.py run on hosts with less system RAM than the model needs to materialize in fp32 on CPU (~12 GB for Lance_3B). On main, the first call `Qwen2ForCausalLM(llm_config)` allocates a freshly-init'd fp32 3B model on CPU and OOM-kills an 8 GB host before any GPU code runs. Changes: - Build LLM / ViT / Lance wrapper under `accelerate.init_empty_weights()` so every nn.Parameter is shape-only on the meta device — near-zero CPU RAM during construction. - Replace `safetensors.safe_open()` with a hand-rolled reader that does plain seek+read of one tensor at a time. safe_open mmaps the whole 12 GB file, which Linux refuses on a host with strict overcommit / no swap (ENOMEM). Peak CPU RAM during load is one tensor at a time. - Pass `dtype=torch.bfloat16` to `set_module_tensor_to_device` so loaded values aren't silently upcast back to the meta tensor's fp32 default. Without this the model lives at fp32 on the GPU, doubling VRAM and breaking the bf16 autocast path (fp32 weights * bf16 activations → fp32 output, then index-put into bf16 destination crashes). - Replace numpy fp64 sin-cos position embeddings with a torch fp32 port that computes on the param's device. PositionEmbedding3D._init_weights used to peak around 4 GB of CPU RAM building 3 intermediate arrays of shape (t*h*w, ~D/3); the GPU version contributes ~zero CPU. - Materialize any params left on `meta` after the load (the popped latent_pos_embed sin-cos buffer) on the target device and re-init. - Set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` so the per-tensor streaming pattern doesn't fragment the CUDA caching allocator into a state where large allocations fail despite plenty of free VRAM. Peak CPU RSS during load stays under ~2 GB. The model loads in bf16 across whatever device the runner picks (cuda:LOCAL_RANK). Multi-GPU sharding for hosts that can't fit the model on one card is a separate follow-up — see SHARDED_LOAD.md. See LOW_RAM_LOAD.md for the full memory profile and per-file rationale. Co-Authored-By: Claude Opus 4.7 --- LOW_RAM_LOAD.md | 108 +++++++++++ benchmarks/sample_env.sh | 8 + inference_lance.py | 302 +++++++++++++++++++++++++------ modeling/lance/modeling_utils.py | 73 +++++++- 4 files changed, 429 insertions(+), 62 deletions(-) create mode 100644 LOW_RAM_LOAD.md diff --git a/LOW_RAM_LOAD.md b/LOW_RAM_LOAD.md new file mode 100644 index 0000000..d55154b --- /dev/null +++ b/LOW_RAM_LOAD.md @@ -0,0 +1,108 @@ +# Low-RAM inference load + +This change lets `inference_lance.py` run on hosts with **less system RAM than +the model needs to materialize in fp32 on CPU** (~12 GB for Lance_3B). It +removes the CPU-side memory spike from the load path. Multi-GPU model +parallelism is layered on top by a separate change — see +[`SHARDED_LOAD.md`](SHARDED_LOAD.md). + +## Why + +On `main`, the first stage of `main()`: + +```python +language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) +``` + +allocates a freshly-init'd fp32 3B model on CPU (~12 GB). On an 8 GB host this +gets OOM-killed before any GPU code runs. The actual checkpoint load +(`load_file → load_state_dict`) makes things worse by holding the full state +dict on CPU as a second copy. Several smaller allocations downstream +(numpy fp64 sin-cos, full-file `safe_open()` mmap) also push past the ceiling. + +## What changed + +### 1. Meta-init the model skeleton + +LLM, ViT, and the `Lance` wrapper are constructed inside +`accelerate.init_empty_weights()`. Every `nn.Parameter` becomes shape-only on +the `meta` device — zero storage. The fp32-on-CPU spike disappears. + +`modeling/lance/modeling_utils.py`'s `PositionEmbedding{,3D}._init_weights` +now early-returns when its param is still meta, deferring sin-cos +materialization until after the load. + +### 2. Stream the checkpoint, don't mmap it + +`safetensors.safe_open()` mmaps the whole 12 GB checkpoint file. On a host +with strict commit accounting and no swap, the kernel refuses a 12 GB +file-backed mapping (`ENOMEM`). The streaming loader (`_stream_load_into`) +opens the file in plain binary mode, reads the 8-byte header length + JSON +header, and seeks to each tensor's data offset. Peak CPU RAM during load is +one tensor at a time — worst case ~1.2 GB for the embedding layer, briefly. + +### 3. Load tensors directly to GPU at bf16 + +Each tensor is read into CPU, cast to bf16, and handed to +`accelerate.utils.set_module_tensor_to_device(model, name, device, +value=tensor, dtype=torch.bfloat16)`. **The `dtype=` argument is +load-bearing**: without it, accelerate silently casts the value to +`old_value.dtype` to match the meta tensor's nominal dtype (fp32 default). +That would both double VRAM and produce fp32 weights that the bf16 autocast +path then promotes back to fp32 mid-attention — eventually crashing on an +index-put dtype mismatch. + +After the load loop, `_materialize_remaining_meta` walks the model for +parameters still on meta (e.g. `latent_pos_embed.pos_embed`, which the +original code popped from the checkpoint to recompute per-resolution), +allocates real storage on the target device, and re-runs `_init_weights()`. + +### 4. Compute sin-cos position embeddings on GPU, not CPU + +`get_3d_sincos_pos_embed` (numpy fp64) used to allocate three intermediate +arrays of shape `(t*h*w, ~D/3)` plus a concatenated copy — peaking around +**4 GB of CPU RAM** for Lance's defaults (`t=31, h=w=64, D=2048`). + +Replaced with `_torch_3d_sincos` / `_torch_2d_sincos` that compute on the +parameter's device in torch fp32. CPU contribution is ~zero. Same change for +the 2D variant used by `PositionEmbedding`. + +### 5. `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` + +The per-tensor streaming pattern fragments the default CUDA caching +allocator enough that big tensors can fail to allocate even when total free +VRAM is plenty. `expandable_segments` coalesces freed regions and lets large +allocations grow into them. Set by default in +`benchmarks/sample_env.sh::lance_setup_common_env` (`${VAR:-default}` so a +user-set value still wins). + +## Memory profile (Lance_3B on a single GPU) + +| Stage | Peak CPU RSS | Notes | +|---|---|---| +| Meta-init LLM/ViT/Lance | ~few hundred MB | torch + python + dataclass overhead | +| ViT streaming load (1.2 GB safetensors) | ~1 GB | one fp32 tensor at a time | +| Lance streaming load (12.3 GB safetensors) | ~1.5 GB | embedding layer is the worst tensor | +| Materialize popped sin-cos pos_embed | tiny | computed on GPU | +| Tokenizer + resize | <500 MB | | + +Peak CPU RSS during load stays under ~2 GB, comfortably below an 8 GB +ceiling. Total VRAM usage on the target card is ~6 GB (Lance_3B in bf16 ++ ViT + VAE), which fits on a single 40 GB GPU but not a 12 GB one — for +that case, see [`SHARDED_LOAD.md`](SHARDED_LOAD.md). + +## What `main` users keep + +For hosts that *do* have enough RAM, this change is still net-positive: +the load is faster (no fp32 → bf16 conversion afterwards, no full state-dict +held on CPU) and uses half the VRAM (bf16 instead of fp32 at rest). The +launcher and config are unchanged in this commit; the behavior change is +transparent to the runner. + +## File-by-file summary + +| File | Change | +|---|---| +| `inference_lance.py` | `init_empty_weights()` for LLM/ViT/Lance; new streaming safetensors reader (`_read_safetensors_header`, `_read_safetensors_tensor`, `_stream_load_into`); `_materialize_remaining_meta`; `_resolve_lance_checkpoint`; passes `dtype=torch.bfloat16` to `set_module_tensor_to_device`; removed the per-batch `.to(device)` calls on the model. | +| `modeling/lance/modeling_utils.py` | New `_torch_2d_sincos` / `_torch_3d_sincos`; `_init_weights` early-returns on meta tensors and otherwise computes on the param's device. | +| `benchmarks/sample_env.sh` | Exports `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (user override respected). | diff --git a/benchmarks/sample_env.sh b/benchmarks/sample_env.sh index 479c500..11f82de 100644 --- a/benchmarks/sample_env.sh +++ b/benchmarks/sample_env.sh @@ -38,6 +38,14 @@ lance_setup_common_env() { export CUDA_LAUNCH_BLOCKING="${CUDA_LAUNCH_BLOCKING:-0}" export NCCL_DEBUG="${NCCL_DEBUG:-VERSION}" export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="${TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC:-900}" + + # The streaming load in inference_lance.py allocates thousands of small-to-large + # tensors onto each GPU as it walks the checkpoint. The default caching allocator + # fragments under that pattern hard enough that a 1.2 GB tensor can fail to + # allocate on a card that still has plenty of total free VRAM. expandable_segments + # coalesces freed regions and lets large allocations grow into them. Required for + # the 5×3060 sharded load to succeed. + export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" } diff --git a/inference_lance.py b/inference_lance.py index 27f49c4..f6e0823 100644 --- a/inference_lance.py +++ b/inference_lance.py @@ -24,13 +24,18 @@ import os.path as osp from copy import deepcopy import json -from typing import Tuple, cast, Optional +from typing import Tuple, cast, Optional, Dict, List import torch import torch.distributed as dist +from torch import nn from torch.utils.data import DataLoader from transformers import HfArgumentParser, set_seed from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig +import struct +import numpy as np from safetensors.torch import load_file +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device from data.dataset_base import DataConfig, simple_custom_collate from data.data_utils import add_special_tokens @@ -114,35 +119,170 @@ }, } -def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments): - # Always load the trained Lance checkpoint from model_path. - path_dir = model_args.model_path - ema_path = osp.join(path_dir, "ema.safetensors") - model_path = osp.join(path_dir, "model.safetensors") +# Names of buffers/params that the original codepath intentionally popped from the +# checkpoint before load (they are fixed sin-cos embeddings rebuilt per resolution). +_POPPED_FROM_CHECKPOINT = frozenset({"latent_pos_embed.pos_embed"}) + + +def _resolve_lance_checkpoint(model_path_dir: str) -> str: + """Return the path of the Lance checkpoint to load (preferring model.safetensors).""" + for fname in ("model.safetensors", "ema.safetensors"): + cand = osp.join(model_path_dir, fname) + if osp.exists(cand): + return cand + raise FileNotFoundError( + f"No Lance checkpoint ('model.safetensors' or 'ema.safetensors') found in {model_path_dir}. " + "Download the full Lance_3B (or Lance_3B_Video) weights with:\n" + ' hf download bytedance-research/Lance --local-dir downloads --include "Lance_3B/*"' + ) + +def _device_for_param(param_name: str, device_map: Dict[str, int]) -> int: + """Find the device assignment for `param_name` by walking up its dotted path.""" + parts = param_name.split(".") + for i in range(len(parts), 0, -1): + prefix = ".".join(parts[:i]) + if prefix in device_map: + return device_map[prefix] + return 0 # default to cuda:0 for any unmapped params (Lance has very few) + + +# safetensors dtype string -> (numpy dtype used to read raw bytes, optional torch view dtype) +# bf16 has no native numpy dtype, so we read as uint16 then bit-cast via tensor.view(). +_SAFE_DTYPE_MAP = { + "F64": (np.float64, None), + "F32": (np.float32, None), + "F16": (np.float16, None), + "BF16": (np.uint16, torch.bfloat16), + "I64": (np.int64, None), + "I32": (np.int32, None), + "I16": (np.int16, None), + "I8": (np.int8, None), + "U8": (np.uint8, None), + "BOOL": (np.bool_, None), +} - model_path_ft = None - if osp.exists(model_path): - model_path_ft = model_path - elif osp.exists(ema_path): - model_path_ft = ema_path - if model_path_ft: - model_state_dict = load_file(model_path_ft, device="cpu") - else: - raise FileNotFoundError( - f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}" +def _read_safetensors_header(f) -> Tuple[Dict, int]: + """Read the 8-byte length + JSON header. Returns (header_dict, data_section_offset).""" + header_len_bytes = f.read(8) + if len(header_len_bytes) != 8: + raise ValueError(f"Truncated safetensors file: only {len(header_len_bytes)}/8 length bytes") + (header_len,) = struct.unpack(" torch.Tensor: + """Read one tensor's bytes via plain seek+read (no mmap) and return a CPU torch tensor.""" + start, end = meta["data_offsets"] + nbytes = end - start + f.seek(data_section_offset + start) + raw = f.read(nbytes) + if len(raw) != nbytes: + raise ValueError(f"Short read: got {len(raw)}/{nbytes} bytes") + np_dtype, view_dtype = _SAFE_DTYPE_MAP[meta["dtype"]] + # .copy() detaches from the read-only `raw` bytes so the buffer can be freed before + # we keep the torch tensor around. Peak CPU memory: one tensor at a time. + np_arr = np.frombuffer(raw, dtype=np_dtype).copy().reshape(meta["shape"]) + del raw + tensor = torch.from_numpy(np_arr) + if view_dtype is not None: + tensor = tensor.view(view_dtype) + return tensor + + +def _stream_load_into( + model: nn.Module, + safetensors_path: str, + device_map: Dict[str, int], + key_prefix: str = "", + skip_keys: frozenset = frozenset(), + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[List[str], List[str]]: + """Stream safetensors into `model`, one tensor at a time, directly onto GPU shards. + + Uses plain `open() + seek() + read()` rather than `safetensors.safe_open()` because + safe_open mmaps the whole file (12 GB for Lance_3B) — the kernel's overcommit policy + rejects that on the 8 GB host. With direct IO peak CPU RAM is one tensor at a time + (worst case ~1.2 GB for the embedding layer at fp32, briefly). + + `key_prefix` is prepended to each safetensors key when looking up the target + parameter in `model` — the ViT file stores bare keys; they live under `vit_model.*` + in the Lance wrapper. + """ + loaded: List[str] = [] + unknown: List[str] = [] + model_keys = set(dict(model.named_parameters()).keys()) | set(dict(model.named_buffers()).keys()) + with open(safetensors_path, "rb") as f: + header, data_section_offset = _read_safetensors_header(f) + for key, meta in header.items(): + if key == "__metadata__": + continue + full_name = f"{key_prefix}{key}" + if full_name in skip_keys: + continue + if full_name not in model_keys: + unknown.append(full_name) + continue + tensor = _read_safetensors_tensor(f, meta, data_section_offset).to(dtype) + device = _device_for_param(full_name, device_map) + # Pass dtype= explicitly: without it, set_module_tensor_to_device casts + # `value` to `old_value.dtype` to match the meta tensor's nominal dtype + # (which is fp32 from init_empty_weights' default). That silently upcasts + # our bf16 tensors back to fp32, doubling VRAM and breaking the autocast + # path (fp32 weights * bf16 activations → fp32 output, then index-put into + # a bf16 destination crashes with a dtype-mismatch error). + set_module_tensor_to_device(model, full_name, device, value=tensor, dtype=dtype) + loaded.append(full_name) + del tensor + return loaded, unknown + + +def _materialize_remaining_meta(model: "Lance", device_map: Dict[str, int], dtype: torch.dtype): + """Allocate any still-meta params on their target devices and re-init the + fixed sin-cos position embeddings (which were popped from the checkpoint).""" + from modeling.lance.modeling_utils import PositionEmbedding, PositionEmbedding3D + + materialized = [] + for name, param in list(model.named_parameters()): + if not param.is_meta: + continue + device = _device_for_param(name, device_map) + # Walk to the owning module to swap the meta param for a real one. + *mod_parts, attr = name.split(".") + owner = model + for m in mod_parts: + owner = getattr(owner, m) + new_param = torch.nn.Parameter( + torch.zeros(param.shape, dtype=dtype, device=f"cuda:{device}"), + requires_grad=param.requires_grad, ) + setattr(owner, attr, new_param) + materialized.append(name) - # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop it off, - # which makes it easier to adapt to different resolutions. - if 'latent_pos_embed.pos_embed' in model_state_dict: - model_state_dict.pop('latent_pos_embed.pos_embed') + # Same for any buffers that ended up meta (rare; defensive). + for name, buf in list(model.named_buffers()): + if not buf.is_meta: + continue + device = _device_for_param(name, device_map) + *mod_parts, attr = name.split(".") + owner = model + for m in mod_parts: + owner = getattr(owner, m) + owner.register_buffer( + attr, torch.zeros(buf.shape, dtype=buf.dtype, device=f"cuda:{device}") + ) + materialized.append(name) - msg = model.load_state_dict(model_state_dict, strict=False) # strict = True | False - clean_memory(model_state_dict) + # Re-run the sin-cos init now that the param tensors are real. + for sub in model.modules(): + if isinstance(sub, (PositionEmbedding, PositionEmbedding3D)): + sub._init_weights() - return msg + return materialized def clean_memory(*objects): @@ -265,7 +405,8 @@ def validate_on_fixed_batch( save_path_gt: str = "", ): val_data = val_data_cpu.cuda(device).to_dict() - fsdp_model = fsdp_model.to(device=device, dtype=torch.bfloat16) + # No fsdp_model.to(device) needed: streaming load already placed weights on the + # target device in bf16. with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): # Compute padded_latent. @@ -450,11 +591,19 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): llm_config.freeze_und = training_args.freeze_und llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb + # ===== Meta-init: build the module skeleton with zero CPU RAM. ===== + # The bare Qwen2ForCausalLM(llm_config) call used to materialize a full fp32 3B + # model on CPU (~12 GB), which is the load step that OOM-killed an 8 GB box. + # Under init_empty_weights() every nn.Parameter is created on the "meta" device + # (shape only, no storage), so this whole block stays at near-zero RAM. stage_start = time.perf_counter() - log_rank0(f"[startup] Initializing LLM weights: {model_args.model_path}") - language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) - log_stage("LLM weight init", stage_start) + log_rank0(f"[startup] Meta-initializing LLM: {model_args.model_path}") + with init_empty_weights(): + language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) + log_stage("LLM meta-init", stage_start) + vit_model = None + vit_config = None if training_args.visual_und: if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"): stage_start = time.perf_counter() @@ -463,17 +612,16 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): log_stage("VIT config load", stage_start) stage_start = time.perf_counter() - log_rank0(f"[startup] Loading VIT weights: {osp.join(model_args.vit_path, 'vit.safetensors')}") - vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config) - vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors")) - vit_model.load_state_dict(vit_weights, strict=True) - log_stage("VIT weight load", stage_start) + log_rank0("[startup] Meta-initializing VIT (weights loaded later from vit.safetensors)") + with init_empty_weights(): + vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config) + log_stage("VIT meta-init", stage_start) else: raise ValueError(f"Unsupported vit_type: {model_args.vit_type}") - clean_memory(vit_weights) - if training_args.visual_gen: + # WanVideoVAE itself uses torch.device("meta") + assign-load internally, so it + # doesn't contribute to the CPU RAM spike. Built eagerly so vae_config is real. stage_start = time.perf_counter() log_rank0("[startup] Initializing VAE") vae_model = WanVideoVAE() @@ -483,7 +631,6 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): vae_model = None vae_config = None - # Lance configuration config = LanceConfig( visual_gen=training_args.visual_gen, visual_und=training_args.visual_und, @@ -498,33 +645,77 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): interpolate_pos=model_args.interpolate_pos, timestep_shift=training_args.timestep_shift, ) - model: Lance = Lance( - language_model=language_model, - vit_model=vit_model if training_args.visual_und else None, - vit_type=model_args.vit_type, - config=config, - training_args=training_args, - ) + stage_start = time.perf_counter() - log_rank0(f"[startup] Moving Lance model to GPU {DEVICE}") - model = model.to(DEVICE) - log_stage("Lance model move to GPU", stage_start) + log_rank0("[startup] Meta-initializing Lance wrapper") + with init_empty_weights(): + model: Lance = Lance( + language_model=language_model, + vit_model=vit_model if training_args.visual_und else None, + vit_type=model_args.vit_type, + config=config, + training_args=training_args, + ) + log_stage("Lance meta-init", stage_start) - # Setup tokenizer for model: + # Single-device load: every tensor is routed onto the current device (cuda:DEVICE). + # `_device_for_param` defaults to 0 for any name not in `device_map`, so an empty + # map is enough here. (A follow-up change adds a smart device-map builder that + # spreads layers across multiple GPUs.) + device_map: Dict[str, int] = {} + + # ===== Stream-load weights directly onto the GPU at bf16. ===== + # ViT weights live in a separate file; in the Lance wrapper they sit under vit_model.*. + if training_args.visual_und: + vit_safetensors = osp.join(model_args.vit_path, "vit.safetensors") + stage_start = time.perf_counter() + log_rank0(f"[startup] Streaming VIT weights from {vit_safetensors}") + vit_loaded, vit_unknown = _stream_load_into( + model, vit_safetensors, device_map, key_prefix="vit_model.", dtype=torch.bfloat16, + ) + log_stage("VIT streaming load", stage_start, + extra=f"loaded={len(vit_loaded)} unknown={len(vit_unknown)}") + if vit_unknown: + log_rank0(f"[startup] WARNING: {len(vit_unknown)} ViT key(s) had no matching param " + f"(first few: {vit_unknown[:5]})") + + # The main Lance checkpoint: covers language_model.*, the connector / vae<->llm / + # time_embedder / latent_pos_embed (popped) / etc. Skip the popped sin-cos buffer. + lance_ckpt = _resolve_lance_checkpoint(model_args.model_path) + stage_start = time.perf_counter() + log_rank0(f"[startup] Streaming Lance checkpoint from {lance_ckpt}") + main_loaded, main_unknown = _stream_load_into( + model, lance_ckpt, device_map, skip_keys=_POPPED_FROM_CHECKPOINT, dtype=torch.bfloat16, + ) + log_stage("Lance streaming load", stage_start, + extra=f"loaded={len(main_loaded)} unknown={len(main_unknown)}") + if main_unknown: + # Many Lance training-time keys (optimizer state, etc.) may not exist on the + # inference model; informational, not fatal. + log_rank0(f"[startup] NOTE: {len(main_unknown)} checkpoint key(s) had no matching param " + f"(first few: {main_unknown[:5]})") + + # Anything still meta (the popped sin-cos pos_embed, any non-checkpointed buffer) + # gets allocated on its target device and re-initialized to the right values. + materialized = _materialize_remaining_meta(model, device_map, dtype=torch.bfloat16) + if materialized: + log_rank0(f"[startup] Materialized {len(materialized)} meta param/buffer(s) post-load " + f"(first few: {materialized[:5]})") + + # init_moe() copies UND weights into the moe_gen slots. For inference from a fully- + # trained Lance checkpoint, the moe_gen weights are already loaded above — running + # init_moe now would either no-op (good) or clobber them with random/loaded weights. + # Skip unconditionally on the meta-init path. + if training_args.copy_init_moe: + log_rank0("[startup] Skipping init_moe(): full checkpoint already contains moe_gen weights.") + + # ===== Tokenizer + post-load patch-ups. ===== stage_start = time.perf_counter() log_rank0(f"[startup] Loading tokenizer: {model_args.model_path}") tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path) - tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer) log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}") - # Initialize MoE before loading the checkpoint. - if training_args.copy_init_moe: - language_model.init_moe() - - init_from_model_path_if_needed(model, model_args) - - # Resize afterward to avoid checkpoint shape mismatches or overwritten weights. if num_new_tokens > 0: model.language_model.resize_token_embeddings(len(tokenizer)) model.config.llm_config.vocab_size = len(tokenizer) @@ -534,7 +725,7 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): from common.model.hacks import hack_qwen2_5_vl_config language_model = hack_qwen2_5_vl_config(language_model) - image_token_id = language_model.config.video_token_id # image_token_id # <|image_pad|> + image_token_id = language_model.config.video_token_id # <|image_pad|> new_token_ids.update({"image_token_id": image_token_id}) model.update_tokenizer(tokenizer=tokenizer) @@ -549,7 +740,6 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): else: assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_word_embeddings conflict' - model = model.to(device=DEVICE, dtype=torch.bfloat16) model.eval() if vae_model is not None and hasattr(vae_model, "eval"): vae_model.eval() diff --git a/modeling/lance/modeling_utils.py b/modeling/lance/modeling_utils.py index 4b24559..0c8e2a5 100644 --- a/modeling/lance/modeling_utils.py +++ b/modeling/lance/modeling_utils.py @@ -160,6 +160,51 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +def _torch_1d_sincos(dim: int, pos: torch.Tensor) -> torch.Tensor: + """Torch port of get_1d_sincos_pos_embed_from_grid; runs on pos.device in fp32.""" + assert dim % 2 == 0 + device = pos.device + omega = torch.arange(dim // 2, dtype=torch.float32, device=device) + omega = 1.0 / (10000.0 ** (omega / (dim / 2.0))) # (D/2,) + out = pos.reshape(-1)[:, None] * omega[None, :] # (M, D/2) + return torch.cat([torch.sin(out), torch.cos(out)], dim=1) # (M, D) + + +def _torch_2d_sincos(embed_dim: int, grid_size: int, device, dtype) -> torch.Tensor: + grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) + grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) + # `np.meshgrid(grid_w, grid_h)` puts width first; torch's `indexing="xy"` matches that. + gw, gh = torch.meshgrid(grid_w, grid_h, indexing="xy") + emb_h = _torch_1d_sincos(embed_dim // 2, gh.flatten()) + emb_w = _torch_1d_sincos(embed_dim // 2, gw.flatten()) + return torch.cat([emb_h, emb_w], dim=1).to(dtype) + + +def _torch_3d_sincos(embed_dim: int, t: int, h: int, w: int, device, dtype) -> torch.Tensor: + """Torch port of get_3d_sincos_pos_embed; computes on `device` in fp32. + + The numpy original allocates three intermediate fp64 arrays of shape (t*h*w, ~D/3) + each plus a concatenated copy, peaking around 4 GB of CPU RAM for Lance's defaults + (t=31, h=w=64, D=2048). Doing the same work in fp32 on a GPU is ~free and avoids + the spike that OOMs the 8 GB host post-load. + """ + assert embed_dim % 2 == 0 + d = embed_dim // 3 + d = d if d % 2 == 0 else d - 1 + dim_t, dim_h = d, d + dim_w = embed_dim - 2 * d + assert dim_w % 2 == 0 + + grid_t = torch.arange(t, dtype=torch.float32, device=device) + grid_h = torch.arange(h, dtype=torch.float32, device=device) + grid_w = torch.arange(w, dtype=torch.float32, device=device) + tt, hh, ww = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + emb_t = _torch_1d_sincos(dim_t, tt.flatten()) + emb_h = _torch_1d_sincos(dim_h, hh.flatten()) + emb_w = _torch_1d_sincos(dim_w, ww.flatten()) + return torch.cat([emb_t, emb_h, emb_w], dim=1).to(dtype) + + class PositionEmbedding(nn.Module): def __init__(self, max_num_patch_per_side, hidden_size): super().__init__() @@ -172,9 +217,18 @@ def __init__(self, max_num_patch_per_side, hidden_size): self._init_weights() def _init_weights(self): - # Initialize (and freeze) pos_embed by sin-cos embedding: - pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + # Skip when constructed under accelerate.init_empty_weights() — the param is on + # meta and cannot be copied into. The caller must materialize the param on a real + # device and re-invoke _init_weights() after dispatch. + if self.pos_embed.is_meta: + return + with torch.no_grad(): + self.pos_embed.data.copy_( + _torch_2d_sincos( + self.hidden_size, self.max_num_patch_per_side, + device=self.pos_embed.device, dtype=self.pos_embed.dtype, + ) + ) def forward(self, position_ids): return self.pos_embed[position_ids] @@ -190,9 +244,16 @@ def __init__(self, max_latent_num_frames, max_latent_size, hidden_size): self._init_weights() def _init_weights(self): - # Initialize (and freeze) pos_embed by sin-cos embedding: - pos_embed = get_3d_sincos_pos_embed(self.hidden_size, self.max_num_latent_frames, self.max_latent_size, self.max_latent_size) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + # See PositionEmbedding._init_weights for the meta-tensor rationale. + if self.pos_embed.is_meta: + return + with torch.no_grad(): + self.pos_embed.data.copy_( + _torch_3d_sincos( + self.hidden_size, self.max_num_latent_frames, self.max_latent_size, self.max_latent_size, + device=self.pos_embed.device, dtype=self.pos_embed.dtype, + ) + ) def forward(self, position_ids): return self.pos_embed[position_ids] From 20709a10b8e02a277fc2e77f8bd724048f2702b1 Mon Sep 17 00:00:00 2001 From: johbau Date: Fri, 29 May 2026 09:30:53 +0200 Subject: [PATCH 2/5] inference: model-parallel sharding across multiple GPUs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Builds on the low-RAM streaming load (previous commit) to enable single- process, model-parallel inference across N GPUs. Lets Lance run on hosts where no single card has enough VRAM but the aggregate does — e.g. 5 × RTX 3060 (60 GB) for Lance_3B + ViT + VAE. Changes: - `_build_lance_device_map(model, num_gpus)` spreads LLM transformer layers across cuda:0..N-1 with cuda:0 getting a reduced share (it also hosts embed/lm_head/ViT/VAE/connectors). A safety net pins any uncovered parameter to cuda:0 so future top-level MoT siblings don't break dispatch. - `accelerate.dispatch_model` installs pre/post forward hooks that move activations between cards as needed. The streaming loader from the previous commit already routes each tensor onto its target shard at load time; this commit just attaches the runtime hooks. - Replace flex_attention with eager-SDPA in all three call sites in lance.py. flex_attention's BlockMask captures device-specific tensors that dynamo refuses to combine with Q/K/V from a different shard. qwen2_navit.py already has an isinstance(attention_mask, List) → SDPA branch that crosses devices cleanly via accelerate hooks. A new helper `_flex_mask_to_dense_list` evaluates the flex mask function on a (q_idx, kv_idx) meshgrid to produce that List. - Align parent-class Python combine sites that dispatch_model's hooks can't reach: - Pin ViT to cuda:0 so its output matches embed_tokens for the inline `masked_scatter` in validation_video_to_text. - Move `packed_sequence` back to the index tensor's device after the Qwen2Model layer loop, so the final norm/lm_head indexing combine works. - Handle both Tensor and List in the per-layer `attention_mask.to(device=…)` call. - Launcher and config: - inference_lance.sh: NUM_GPUS=5 default (now means shard count, not data-parallel rank count), `--num_processes 1`, forwards `--shard_num_gpus $NUM_GPUS`. - InferenceArguments: new `shard_num_gpus: int = 0`. 0 = use all visible GPUs; >0 caps to that many. Behavior on a 1-GPU host is unchanged — the device map collapses to "everything on cuda:0" and dispatch_model is skipped. The dense-mask SDPA replacement runs unconditionally; flex_attention can be gated behind a flag if a single-card user wants the compiled kernel back. Memory profile on a 5 × 3060 / 8 GB RAM host: cuda:0 ~6 GB (3 LLM layers + ViT + VAE + embed + lm_head + extras) cuda:1 ~3 GB (8 LLM layers) cuda:2 ~3 GB (8 LLM layers) cuda:3 ~3 GB (8 LLM layers) cuda:4 ~3 GB (8 LLM layers) Smoke test (x2t_image, 768 res, 6 cases) completes successfully. About 67 s per understanding batch — slow because activations shuttle across PCIe between cards and SDPA is eager. The point is fitting the model on this hardware, not throughput. See SHARDED_LOAD.md for the full rationale and per-file summary. Co-Authored-By: Claude Opus 4.7 --- SHARDED_LOAD.md | 159 ++++++++++++++++++++++++++++++++++ config/config_factory.py | 5 ++ inference_lance.py | 105 +++++++++++++++++++--- inference_lance.sh | 9 +- modeling/lance/lance.py | 44 ++++++++-- modeling/lance/qwen2_navit.py | 15 +++- 6 files changed, 316 insertions(+), 21 deletions(-) create mode 100644 SHARDED_LOAD.md diff --git a/SHARDED_LOAD.md b/SHARDED_LOAD.md new file mode 100644 index 0000000..0bbc213 --- /dev/null +++ b/SHARDED_LOAD.md @@ -0,0 +1,159 @@ +# Sharded model-parallel inference + +This change adds **single-process, model-parallel inference across N GPUs**. +It's the second half of the work that lets Lance run on an 8 GB system RAM ++ 5 × RTX 3060 host. The first half — getting the model loaded without +materializing fp32 weights on CPU — is in +[`LOW_RAM_LOAD.md`](LOW_RAM_LOAD.md) and must be in place first; this change +builds on the streaming loader and `init_empty_weights()` infrastructure +introduced there. + +## Why + +A 3B Lance model in bf16 is ~6 GB of weights, plus ViT (~1.2 GB) and the +WanVideoVAE (~2-3 GB) for generation tasks. None of that fits on a single +12 GB 3060 with room for activations. But it does fit comfortably across +the **60 GB aggregate VRAM** of five 3060s if the LLM's transformer layers +are sharded across cards. Doing that requires: + +- A device map that puts the right layers on the right cards. +- Cross-card forward hooks (`accelerate.dispatch_model`). +- Source-side fixes everywhere Lance's existing code assumes everything + lives on one device. + +The previous launcher ran `accelerate launch --num_processes=$NUM_GPUS`, +which is **data-parallel**: each process gets its own full copy of the +model. That doesn't help here — each rank still needs to fit the whole +model, and CPU RAM pressure goes up N× (one copy materialized per process). +For inference we want model-parallel: one process, model split across cards. + +## What changed + +### 1. Device map + +`_build_lance_device_map(model, num_gpus)` (in `inference_lance.py`) builds +a `{module_name: gpu_index}` map: + +- LLM transformer layers split across `cuda:1..N-1`, with `cuda:0` getting + a **reduced** share (about half of the even split) because cuda:0 also + hosts the entry/exit modules and the WanVideoVAE. +- `embed_tokens`, `norm`, `norm_moe_gen` (MoT generation-branch sibling), + `rotary_emb`, `lm_head` pinned to `cuda:0` — these are the token-flow + boundaries. +- ViT pinned to `cuda:0` because `Lance.validation_video_to_text` combines + ViT output with `embed_tokens` output via `masked_scatter` *inline* + (lance.py around line 1010). That combine happens in parent-class Python, + not inside a submodule's `forward()`, so accelerate's hooks don't get a + chance to align devices. +- Connector / time_embedder / vae2llm / llm2vae / latent_pos_embed all on + cuda:0 (small). +- Safety net: any parameter not covered by an explicit prefix lands on + cuda:0. Without this, `dispatch_model` rejects the device_map with a + hard error the first time someone adds a top-level MoT sibling we didn't + anticipate. + +### 2. `accelerate.dispatch_model` + +Installs pre/post forward hooks on each dispatched submodule so activations +get moved to the right card before each `.forward()`. After this point the +model must **not** be `.to()`-d (that would collapse every shard onto one +card). The per-batch `fsdp_model.to(device, dtype=bf16)` call in +`validate_on_fixed_batch` was already removed in the low-RAM change. + +The streaming loader from the low-RAM change already supports a non-empty +device_map — every tensor is routed onto the GPU dictated by +`_device_for_param(name, device_map)` at load time, so the model is on +its shards *before* hooks are attached. + +### 3. Replace `flex_attention` with eager-SDPA dense masks + +`flex_attention`'s `BlockMask` captures device-specific tensors when it's +built. Under model parallelism, a layer on `cuda:>0` calls `flex_attention` +with `q/k/v` on that shard and a mask whose captures live on `cuda:0`; +dynamo's tracer refuses to combine them with +`Unhandled FakeTensor Device Propagation`. + +The fix uses a path that already exists in `qwen2_navit.py`: when +`attention_mask` is a `List`, the attention forward iterates per-sample and +runs `torch.nn.functional.scaled_dot_product_attention` instead of +`flex_attention`. SDPA has no dynamo trace and crosses devices cleanly via +the standard accelerate hooks. + +`_flex_mask_to_dense_list(mask_fn, seqlen, device, dtype)` evaluates the +flex mask function on a meshgrid of `(q_idx, kv_idx)` to get a bool mask, +converts to additive float (`-inf` where masked), and returns it as a +single-element `List`. All three `create_block_mask` call sites in +`lance.py` (one in `process_attention_mask`, one in the main `forward`, +one in `validation_video_to_text`) route through this helper. + +### 4. Parent-class device-alignment fixes + +A few places in `lance.py` and `qwen2_navit.py` combine tensors from +different shards in parent-Python (not inside a submodule's `forward()`), +which accelerate's hooks cannot reach. Each was fixed locally: + +- `qwen2_navit.py:619` — at the start of each layer's `forward_train`, + `attention_mask.to(device=packed_sequence_.device)` now handles both + the old single-Tensor BlockMask path and the new List-of-Tensors SDPA + path. +- `qwen2_navit.py:901` — after the layer loop in `Qwen2Model.forward_train`, + `packed_sequence` lives on whichever shard ran the last layer (e.g. + `cuda:N-1`). The index tensors and the final `norm` live on `cuda:0`. + Added one `packed_sequence.to(packed_und_token_indexes.device)` to + consolidate before the indexing-based combine. + +### 5. Launcher and config + +- `inference_lance.sh`: + - `NUM_GPUS=5` default (was 1). It now means "number of shards", not + "number of data-parallel processes". + - `accelerate launch --num_processes 1` always — model parallelism is + inside one process. + - Forwards `--shard_num_gpus $NUM_GPUS` to the Python side. +- `config/config_factory.py`: adds `shard_num_gpus: int = 0` to + `InferenceArguments`. `0` means "use all visible GPUs" + (`torch.cuda.device_count()`); >0 caps to that many. + +## Memory profile (Lance_3B, x2t_image, 5 × 3060) + +| Card | Holds | VRAM | +|---|---|---| +| cuda:0 | 3 LLM layers + embed + lm_head + ViT + VAE + latent_pos_embed + connectors + CUDA context | ~6 GB | +| cuda:1 | 8 LLM layers | ~3 GB | +| cuda:2 | 8 LLM layers | ~3 GB | +| cuda:3 | 8 LLM layers | ~3 GB | +| cuda:4 | 8 LLM layers | ~3 GB | + +The smoke test (`x2t_image`, 768 res, 6 cases) completes successfully. + +## Performance + +About 67 s per understanding batch at 768 resolution on the 5×3060 rig. +This is *slow* because: + +- Every layer's attention runs eager SDPA with a dense mask instead of + `flex_attention`'s compiled kernel. +- Activations shuttle across PCIe between cards via `dispatch_model`'s + hooks at each layer boundary. + +The point of this change is **fitting** the model on this hardware, not +throughput. A single A100 40 GB (cloud fallback) is the right move if you +need real speed. + +## What `main` users keep + +`shard_num_gpus=0` (the default) defers to `torch.cuda.device_count()`, +so on a 1-GPU host the device map collapses to "everything on cuda:0" +and `dispatch_model` is skipped. The dense-mask SDPA replacement does +run unconditionally — if you want `flex_attention` back for a single-card +setup, that's the one piece that's worth gating behind a flag. + +## File-by-file summary + +| File | Change | +|---|---| +| `inference_lance.py` | New `_build_lance_device_map`; `dispatch_model` import + call when sharding > 1; `shard_num_gpus` arg threading. | +| `inference_lance.sh` | `NUM_GPUS=5` default, `--num_processes 1`, passes `--shard_num_gpus`. | +| `config/config_factory.py` | Adds `shard_num_gpus: int = 0` to `InferenceArguments`. | +| `modeling/lance/lance.py` | New `_flex_mask_to_dense_list`; all three `create_block_mask` sites route through it. | +| `modeling/lance/qwen2_navit.py` | Layer `attention_mask.to(device=…)` handles List; `Qwen2Model.forward_train` moves `packed_sequence` back to the index device after the layer loop. | diff --git a/config/config_factory.py b/config/config_factory.py index 5797364..2946979 100644 --- a/config/config_factory.py +++ b/config/config_factory.py @@ -307,6 +307,11 @@ class InferenceArguments(TrainingArguments): use_KVcache: bool = False enhance_prompt: bool = False # Rewrite T2V prompts before inference when enabled. + # Model-parallel sharding for low-RAM hosts: + # 0 = use all visible GPUs (torch.cuda.device_count()). + # >0 = shard Lance's LLM layers across this many GPUs via accelerate.dispatch_model. + shard_num_gpus: int = 0 + @dataclass class EvaluationArguments(InferenceArguments): diff --git a/inference_lance.py b/inference_lance.py index f6e0823..36605c3 100644 --- a/inference_lance.py +++ b/inference_lance.py @@ -34,7 +34,7 @@ import struct import numpy as np from safetensors.torch import load_file -from accelerate import init_empty_weights +from accelerate import init_empty_weights, dispatch_model from accelerate.utils import set_module_tensor_to_device from data.dataset_base import DataConfig, simple_custom_collate @@ -137,6 +137,79 @@ def _resolve_lance_checkpoint(model_path_dir: str) -> str: ) +def _build_lance_device_map(model: "Lance", num_gpus: int) -> Dict[str, int]: + """Spread Lance's LLM transformer layers across `num_gpus` cards. + + cuda:0 is the "entry/exit" device for tokens and logits (embed + lm_head + norm), + and the WanVideoVAE always auto-lands on cuda:0 (its constructor calls + `get_device()` = cuda:LOCAL_RANK, which is cuda:0 in single-process mode and can't + easily be moved post-hoc). Those fixed-cost residents eat ~3-4 GB on cuda:0 before + a single LLM layer lands there, so we explicitly give cuda:0 a *reduced* layer + share when num_gpus >= 2 and park the ViT on the last GPU (typically the lightest + after layer-count remainder). + """ + num_layers = len(model.language_model.model.layers) + num_gpus = max(1, num_gpus) + device_map: Dict[str, int] = {} + + if num_gpus == 1: + # Everything on cuda:0. Almost certainly won't fit Lance_3B + VAE on a single + # 12 GB card — but that's the user's choice (smoke-test scenario only). + for i in range(num_layers): + device_map[f"language_model.model.layers.{i}"] = 0 + else: + # cuda:0 gets roughly half its even share; the remainder spreads across + # cuda:1..N-1. For 36 layers / 5 GPUs that's 3 on cuda:0 and 8-9 elsewhere. + gpu0_layer_count = max(1, num_layers // (2 * num_gpus)) + remaining = num_layers - gpu0_layer_count + other_gpus = num_gpus - 1 + layers_per_other = (remaining + other_gpus - 1) // other_gpus # ceil-div + for i in range(num_layers): + if i < gpu0_layer_count: + device_map[f"language_model.model.layers.{i}"] = 0 + else: + idx = i - gpu0_layer_count + gpu = 1 + min(idx // layers_per_other, other_gpus - 1) + device_map[f"language_model.model.layers.{i}"] = gpu + + # Token entry/exit and both MoT norms pinned to cuda:0. `norm_moe_gen` is the + # generation-branch sibling of `norm`; it must be on the same device because the + # forward path indexes a shared sequence and dispatches by token type. + device_map["language_model.model.embed_tokens"] = 0 + device_map["language_model.model.norm"] = 0 + if hasattr(model.language_model.model, "norm_moe_gen"): + device_map["language_model.model.norm_moe_gen"] = 0 + if hasattr(model.language_model.model, "rotary_emb"): + device_map["language_model.model.rotary_emb"] = 0 + device_map["language_model.lm_head"] = 0 + + # Lance heads. Small (a few MB each) except latent_pos_embed (~250 MB sin-cos); + # keep them all near the embed/connector on cuda:0. + for extra in ("connector", "time_embedder", "vae2llm", "llm2vae", + "latent_pos_embed", "task_embedding", "modality_embedding"): + if hasattr(model, extra) and getattr(model, extra) is not None: + device_map[extra] = 0 + + # ViT must live on cuda:0. Lance.validation_video_to_text combines ViT outputs + # with embed_tokens outputs via `masked_scatter` inline (lance.py:1010) — that + # combine happens in parent-class Python, not inside a submodule's forward(), so + # accelerate's hooks don't get a chance to align devices. cuda:0 now hosts only + # 3 LLM layers (instead of an even 8), so there's ~5 GB of headroom for the ViT + # (~1.2 GB) on top of the VAE/embed/lm_head residents. + if hasattr(model, "vit_model") and model.vit_model is not None: + device_map["vit_model"] = 0 + + # Safety net: any parameter not covered by an explicit prefix above (e.g. a future + # top-level MoT sibling we didn't anticipate) lands on cuda:0. Without this, + # accelerate.dispatch_model rejects the device_map with a hard error. + covered_prefixes = list(device_map.keys()) + for param_name, _ in model.named_parameters(): + if not any(param_name == p or param_name.startswith(p + ".") for p in covered_prefixes): + device_map[param_name] = 0 + + return device_map + + def _device_for_param(param_name: str, device_map: Dict[str, int]) -> int: """Find the device assignment for `param_name` by walking up its dotted path.""" parts = param_name.split(".") @@ -405,8 +478,9 @@ def validate_on_fixed_batch( save_path_gt: str = "", ): val_data = val_data_cpu.cuda(device).to_dict() - # No fsdp_model.to(device) needed: streaming load already placed weights on the - # target device in bf16. + # Do NOT call fsdp_model.to(device) here: the model is sharded across multiple GPUs + # via accelerate.dispatch_model, and .to() would collapse all shards onto one card. + # Weights are already bf16 from the streaming load. with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): # Compute padded_latent. @@ -658,13 +732,14 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): ) log_stage("Lance meta-init", stage_start) - # Single-device load: every tensor is routed onto the current device (cuda:DEVICE). - # `_device_for_param` defaults to 0 for any name not in `device_map`, so an empty - # map is enough here. (A follow-up change adds a smart device-map builder that - # spreads layers across multiple GPUs.) - device_map: Dict[str, int] = {} + # ===== Decide how to shard across GPUs. ===== + num_visible_gpus = torch.cuda.device_count() + shard_n = inference_args.shard_num_gpus or num_visible_gpus + shard_n = max(1, min(shard_n, num_visible_gpus)) + log_rank0(f"[startup] Sharding Lance across {shard_n} GPU(s) (visible: {num_visible_gpus})") + device_map = _build_lance_device_map(model, shard_n) - # ===== Stream-load weights directly onto the GPU at bf16. ===== + # ===== Stream-load weights directly onto each shard's GPU at bf16. ===== # ViT weights live in a separate file; in the Lance wrapper they sit under vit_model.*. if training_args.visual_und: vit_safetensors = osp.join(model_args.vit_path, "vit.safetensors") @@ -704,8 +779,8 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): # init_moe() copies UND weights into the moe_gen slots. For inference from a fully- # trained Lance checkpoint, the moe_gen weights are already loaded above — running - # init_moe now would either no-op (good) or clobber them with random/loaded weights. - # Skip unconditionally on the meta-init path. + # init_moe now would either no-op (good) or clobber them with sharded cross-device + # state_dict() copies (bad). Skip unconditionally on the meta-init path. if training_args.copy_init_moe: log_rank0("[startup] Skipping init_moe(): full checkpoint already contains moe_gen weights.") @@ -717,6 +792,8 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}") if num_new_tokens > 0: + # Embedding and lm_head are both pinned to cuda:0 in the device_map, so + # resize_token_embeddings can do its in-place resize without crossing devices. model.language_model.resize_token_embeddings(len(tokenizer)) model.config.llm_config.vocab_size = len(tokenizer) model.language_model.config.vocab_size = len(tokenizer) @@ -740,6 +817,12 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): else: assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_word_embeddings conflict' + # ===== Attach cross-device hooks so activations flow between shards. ===== + # dispatch_model walks `device_map` and installs pre/post forward hooks that move + # activations to the right card before each submodule runs. After this point, the + # model must NOT be .to()'d as that would collapse the shards. + if shard_n > 1: + model = dispatch_model(model, device_map=device_map) model.eval() if vae_model is not None and hasattr(vae_model, "eval"): vae_model.eval() diff --git a/inference_lance.sh b/inference_lance.sh index 3a5b959..4c6db5d 100755 --- a/inference_lance.sh +++ b/inference_lance.sh @@ -5,7 +5,11 @@ cd "$SCRIPT_DIR" source "$SCRIPT_DIR/benchmarks/sample_env.sh" # ========================= Inference Parameters ========================= -NUM_GPUS=${NUM_GPUS:-1} +# NUM_GPUS is the number of GPUs to *shard* Lance across (model-parallel), not the +# number of replicas. Launch always runs a single process; the Python side uses +# accelerate.dispatch_model to split the LLM's transformer layers across NUM_GPUS +# cards. Default matches the 5×3060 host; override if running on fewer cards. +NUM_GPUS=${NUM_GPUS:-5} TASK_NAME=${TASK_NAME:-x2t_image} # t2i | image_edit | t2v | i2v | video_edit | x2t_image | x2t_video @@ -122,12 +126,13 @@ fi accelerate launch \ --num_machines $NUM_MACHINES \ - --num_processes $TOTAL_RANK \ + --num_processes 1 \ --machine_rank $MACHINE_RANK \ --main_process_ip $MAIN_PROCESS_IP \ --main_process_port $MAIN_PROCESS_PORT \ --mixed_precision bf16 \ inference_lance.py \ + --shard_num_gpus $NUM_GPUS \ --model_path "$MODEL_PATH" \ --vit_type qwen_2_5_vl_original \ --llm_qk_norm true \ diff --git a/modeling/lance/lance.py b/modeling/lance/lance.py index c425fe6..18d1746 100644 --- a/modeling/lance/lance.py +++ b/modeling/lance/lance.py @@ -40,6 +40,36 @@ from data.common import shift_position_ids from copy import deepcopy +def _flex_mask_to_dense_list( + mask_fn, + seqlen: int, + device, + dtype: torch.dtype = torch.bfloat16, +): + """Convert flex_attention's mask function (a closure over device-specific tensors) + into a List[Tensor] of dense additive masks usable by scaled_dot_product_attention. + + flex_attention is incompatible with accelerate's model-parallel dispatch: the + BlockMask captures tensors on the device where it was built, and dynamo's tracer + refuses to combine them with Q/K/V tensors on a different shard. The attention + forward already has a List-of-masks branch that runs eager SDPA per sample (see + qwen2_navit.py `if isinstance(attention_mask, List)`), and SDPA crosses devices + cleanly via the standard accelerate hooks. Calling this helper at every + create_block_mask site funnels the layer attention into that SDPA branch. + """ + q_idx = torch.arange(seqlen, device=device) + kv_idx = torch.arange(seqlen, device=device) + qq, kk = torch.meshgrid(q_idx, kv_idx, indexing="ij") + # `and_masks`/`or_masks` from flex_attention call `b.new_ones(...)` on the batch + # arg, so b/h must be tensors (not ints). Sub-masks ignore b/h anyway. + b = torch.zeros((), dtype=torch.long, device=device) + h = torch.zeros((), dtype=torch.long, device=device) + bool_mask = mask_fn(b, h, qq, kk) + dense = torch.zeros((seqlen, seqlen), dtype=dtype, device=device) + dense.masked_fill_(~bool_mask, float("-inf")) + return [dense] + + class LanceConfig(PretrainedConfig): def __init__( self, @@ -140,10 +170,8 @@ def process_attention_mask(self, current_attn_modes, current_split_lens, current current_attn_modes_ = ["full" if mode_ in ["full_noise", "full_noise_target"] else mode_ for mode_ in current_attn_modes] sparse_mask = create_sparse_mask(current_seq_len, current_split_lens, current_attn_modes_, device) current_seq_len_sum = sum(current_seq_len) - attention_mask = create_block_mask( - sparse_mask, B=1, H=self.num_heads, Q_LEN=current_seq_len_sum, KV_LEN=current_seq_len_sum, device=device, BLOCK_SIZE=BLOCK_SIZE, _compile=False - ) - return attention_mask + # Dense mask List → SDPA branch in qwen2_navit.py (model-parallel safe). + return _flex_mask_to_dense_list(sparse_mask, current_seq_len_sum, device) def forward( self, @@ -239,8 +267,9 @@ def forward( if nested_attention_masks is None: attn_modes_ = ["full" if mode=="full_noise" else mode for mode in attn_modes] sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes_, packed_text_embedding.device) - seqlen = sum(sample_lens) - attention_mask = create_block_mask(sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, device=packed_text_embedding.device, BLOCK_SIZE=BLOCK_SIZE, _compile=True) + seqlen = sum(sample_lens) # 始终是max_num_tokens + # Dense mask List → SDPA branch (model-parallel safe). + attention_mask = _flex_mask_to_dense_list(sparse_mask, seqlen, packed_text_embedding.device) else: attention_mask = nested_attention_masks @@ -907,7 +936,8 @@ def validation_video_to_text( current_text_len = (step + 1) - (num_text_ids - 1) current_split_lens_ = current_split_lens + [current_text_len, num_pad + 1 - current_text_len] sparse_mask = create_sparse_mask(current_sample_lens, current_split_lens_, current_attn_modes_, device) - attention_mask = create_block_mask(sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, device=device, BLOCK_SIZE=BLOCK_SIZE, _compile=False) + # Dense mask List → SDPA branch (model-parallel safe). + attention_mask = _flex_mask_to_dense_list(sparse_mask, seqlen, device) extra_inputs = {"mode": "und"} if self.use_moe: diff --git a/modeling/lance/qwen2_navit.py b/modeling/lance/qwen2_navit.py index 8f3c8d8..1d7b1ed 100644 --- a/modeling/lance/qwen2_navit.py +++ b/modeling/lance/qwen2_navit.py @@ -616,7 +616,13 @@ def forward_train( # Self Attention if attention_mask is not None: - attention_mask = attention_mask.to(device=packed_sequence_.device) + # Mask may be a BlockMask (single tensor) or a List of per-sample dense + # masks (model-parallel path that routes attention through SDPA). Move + # each element onto this layer's shard so SDPA's device check passes. + if isinstance(attention_mask, list): + attention_mask = [m.to(device=packed_sequence_.device) for m in attention_mask] + else: + attention_mask = attention_mask.to(device=packed_sequence_.device) packed_sequence_ = self.self_attn( packed_sequence=packed_sequence_, @@ -892,6 +898,13 @@ def forward_train( **kwargs, ) + # Model-parallel: after the layer loop, packed_sequence lives on whichever + # shard ran the last layer (e.g. cuda:4). The index tensors and norm modules + # are pinned to cuda:0. Move the sequence back so the parent-level indexing + # below combines tensors on a single device. + if self.use_moe and packed_sequence.device != packed_und_token_indexes.device: + packed_sequence = packed_sequence.to(packed_und_token_indexes.device) + if self.use_moe: packed_sequence_ = torch.zeros_like(packed_sequence) packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes]).to(dtype=packed_sequence.dtype) From 3a19b3dc451986ee0759f3b98dea6a04329ad2d7 Mon Sep 17 00:00:00 2001 From: johbau Date: Fri, 19 Jun 2026 08:54:07 +0200 Subject: [PATCH 3/5] inference: make the generation path (t2i/t2v) sharding-safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The diffusion + VAE-decode generation path uses forward_inference (KVcache) and the WanVideoVAE, neither exercised by the understanding path. Three model-parallel fixes were needed to run t2i/t2v across multiple GPUs: - forward_inference gen-mode norm (qwen2_navit.py): inference-mode twin of the forward_train post-loop fix. After the decoder-layer loop the sequence is on the last shard, but the index tensors + norm/norm_moe_gen are on cuda:0. Added a .to(packed_text_indexes.device) guard before the gen-mode index-put. Runs every diffusion timestep. (The KVcache path attends via flash_attn_varlen_func, not flex_attention, so the dense-mask change is not exercised here.) - VAE device override (modeling/vae/wan/model.py): WanVideoVAE hard-coded get_device()=cuda:0. It now accepts a device= (default get_device(), so single-GPU is unchanged) used by configure_vae_model/vae_encode/vae_decode, so the VAE can live on a less-crowded card. - Dedicated VAE card (inference_lance.py): the video VAE decode's conv activations (~9 GB at 480^2 / 17 frames) won't fit on a card that also holds LLM layers. For generation tasks, _build_lance_device_map now takes reserve_last_for_vae and shards the LLM across the first N-1 cards, leaving the last card empty of LLM weights; the VAE is built there. Confirmed on the 8 GB-RAM / 5x3060 host: t2i produces a clean PNG and t2v produces a valid 480x480 17-frame h264 mp4 (1.42 s). Note: 768^2 video decode exceeds a single 12 GB card even when dedicated — that needs VAE decode tiling (not implemented). The launcher's VIDEO_HEIGHT/WIDTH default to 768, so pass --VIDEO_HEIGHT 480 --VIDEO_WIDTH 480 for t2v on 12 GB cards. See SHARDED_LOAD.md. Co-Authored-By: Claude Opus 4.7 --- SHARDED_LOAD.md | 62 ++++++++++++++++++++++++++++++----- inference_lance.py | 62 ++++++++++++++++++++++------------- modeling/lance/qwen2_navit.py | 9 +++++ modeling/vae/wan/model.py | 5 +++ 4 files changed, 106 insertions(+), 32 deletions(-) diff --git a/SHARDED_LOAD.md b/SHARDED_LOAD.md index 0bbc213..905b604 100644 --- a/SHARDED_LOAD.md +++ b/SHARDED_LOAD.md @@ -114,17 +114,60 @@ which accelerate's hooks cannot reach. Each was fixed locally: `InferenceArguments`. `0` means "use all visible GPUs" (`torch.cuda.device_count()`); >0 caps to that many. -## Memory profile (Lance_3B, x2t_image, 5 × 3060) +### 6. Generation tasks (t2i / t2v): the diffusion + VAE-decode path + +The generation tasks exercise code the understanding path doesn't, and each +needed a model-parallel fix: + +- **`forward_inference` gen-mode norm** (`qwen2_navit.py`, the `mode=="gen"` + branch after the decoder-layer loop). This is the inference-mode twin of the + `forward_train` fix above: after the layer loop `packed_query_sequence` is on + the last shard, but `packed_text_indexes`/`packed_vae_token_indexes` and the + `norm`/`norm_moe_gen` modules are on cuda:0. Added a + `.to(packed_text_indexes.device)` guard before the index-put. This runs on + every diffusion timestep. (The KVcache generation path attends via + `flash_attn_varlen_func`, *not* flex_attention, so the dense-mask change in §3 + isn't even exercised here — no regression risk.) + +- **VAE placement** (`modeling/vae/wan/model.py` + `inference_lance.py`). The + WanVideoVAE used to hard-code `get_device()` = cuda:0. `WanVideoVAE` now + accepts a `device=` (default `get_device()`, so single-GPU is unchanged), and + the launcher builds it on the **last** shard. + +- **Dedicated VAE card** (`_build_lance_device_map(..., reserve_last_for_vae=True)` + for generation tasks). The video VAE decode's conv activations are large + enough (~9 GB at 480² / 17 frames) that they won't fit on a card also holding + LLM layers. For generation, the LLM is sharded across the first `N-1` cards + and the last card is left empty of LLM weights so the VAE decode has a + near-full 12 GB to itself. + +**Resolution limit (important):** even a dedicated 12 GB card can't decode a +768²×17-frame video — that single-chunk conv activation peaks just over 12 GB. +480²×17 frames fits with ~2 GB to spare. Larger frames/resolution would need +VAE **decode tiling** (spatial patches with overlap-blend), which is not +implemented here. Note the launcher's `VIDEO_HEIGHT`/`VIDEO_WIDTH` default to +**768**, independent of `--RESOLUTION`; pass `--VIDEO_HEIGHT 480 --VIDEO_WIDTH 480` +for t2v on a 12 GB card. + +## Memory profile + +### Understanding (`x2t_image` / `x2t_video`), Lance_3B, 5 × 3060 | Card | Holds | VRAM | |---|---|---| -| cuda:0 | 3 LLM layers + embed + lm_head + ViT + VAE + latent_pos_embed + connectors + CUDA context | ~6 GB | -| cuda:1 | 8 LLM layers | ~3 GB | -| cuda:2 | 8 LLM layers | ~3 GB | -| cuda:3 | 8 LLM layers | ~3 GB | -| cuda:4 | 8 LLM layers | ~3 GB | +| cuda:0 | ~3 LLM layers + embed + lm_head + ViT + VAE + latent_pos_embed + connectors + CUDA context | ~6 GB | +| cuda:1–4 | ~8 LLM layers each | ~3 GB each | -The smoke test (`x2t_image`, 768 res, 6 cases) completes successfully. +### Generation (`t2v`, reserve-VAE-card), Lance_3B_Video, 480² / 17 frames + +| Card | Holds | VRAM | +|---|---|---| +| cuda:0 | ~4 LLM layers + embed + lm_head + ViT + connectors | ~5 GB | +| cuda:1–3 | ~10–11 LLM layers each | ~3.7 GB each | +| cuda:4 | VAE only (decode peaks here) | ~0.8 GB idle → ~5 GB during decode | + +Smoke tests confirmed: `x2t_image`, `x2t_video`, `t2i`, and `t2v` (480², 17 +frames, 1.42 s mp4) all complete on the 8 GB-RAM / 5×3060 host. ## Performance @@ -152,8 +195,9 @@ setup, that's the one piece that's worth gating behind a flag. | File | Change | |---|---| -| `inference_lance.py` | New `_build_lance_device_map`; `dispatch_model` import + call when sharding > 1; `shard_num_gpus` arg threading. | +| `inference_lance.py` | New `_build_lance_device_map` (with `reserve_last_for_vae`); `dispatch_model` import + call when sharding > 1; `shard_num_gpus` arg threading; VAE built on the last shard. | | `inference_lance.sh` | `NUM_GPUS=5` default, `--num_processes 1`, passes `--shard_num_gpus`. | | `config/config_factory.py` | Adds `shard_num_gpus: int = 0` to `InferenceArguments`. | | `modeling/lance/lance.py` | New `_flex_mask_to_dense_list`; all three `create_block_mask` sites route through it. | -| `modeling/lance/qwen2_navit.py` | Layer `attention_mask.to(device=…)` handles List; `Qwen2Model.forward_train` moves `packed_sequence` back to the index device after the layer loop. | +| `modeling/lance/qwen2_navit.py` | Layer `attention_mask.to(device=…)` handles List; `Qwen2Model.forward_train` and the `forward_inference` gen-mode branch move the sequence back to the index device after the layer loop. | +| `modeling/vae/wan/model.py` | `WanVideoVAE` accepts a `device=` override (default `get_device()`); `configure_vae_model`/`vae_encode`/`vae_decode` use it, so the VAE can live on a card other than cuda:0. | diff --git a/inference_lance.py b/inference_lance.py index 36605c3..eb3eed8 100644 --- a/inference_lance.py +++ b/inference_lance.py @@ -137,32 +137,38 @@ def _resolve_lance_checkpoint(model_path_dir: str) -> str: ) -def _build_lance_device_map(model: "Lance", num_gpus: int) -> Dict[str, int]: - """Spread Lance's LLM transformer layers across `num_gpus` cards. - - cuda:0 is the "entry/exit" device for tokens and logits (embed + lm_head + norm), - and the WanVideoVAE always auto-lands on cuda:0 (its constructor calls - `get_device()` = cuda:LOCAL_RANK, which is cuda:0 in single-process mode and can't - easily be moved post-hoc). Those fixed-cost residents eat ~3-4 GB on cuda:0 before - a single LLM layer lands there, so we explicitly give cuda:0 a *reduced* layer - share when num_gpus >= 2 and park the ViT on the last GPU (typically the lightest - after layer-count remainder). +def _build_lance_device_map(model: "Lance", num_gpus: int, reserve_last_for_vae: bool = False) -> Dict[str, int]: + """Spread Lance's LLM transformer layers across the available cards. + + cuda:0 is the "entry/exit" device for tokens and logits (embed + lm_head + norm) + and also hosts the ViT. Those fixed-cost residents eat ~2-3 GB on cuda:0 before a + single LLM layer lands there, so we give cuda:0 a *reduced* layer share. + + `reserve_last_for_vae`: when True (generation tasks), the LLM is sharded across + only the first `num_gpus - 1` cards, leaving the last GPU empty of LLM weights so + the WanVideoVAE (built on that card) has a near-full 12 GB for its decode. The + video VAE decode's conv activations (~9 GB at 480p/17 frames) won't fit on a card + that also holds LLM layers, so a dedicated card is the simplest robust fix. """ num_layers = len(model.language_model.model.layers) num_gpus = max(1, num_gpus) + + # Number of cards the LLM may use. Reserve the last one for the VAE on generation. + llm_gpus = num_gpus - 1 if (reserve_last_for_vae and num_gpus >= 2) else num_gpus + llm_gpus = max(1, llm_gpus) + device_map: Dict[str, int] = {} - if num_gpus == 1: - # Everything on cuda:0. Almost certainly won't fit Lance_3B + VAE on a single - # 12 GB card — but that's the user's choice (smoke-test scenario only). + if llm_gpus == 1: + # All LLM layers on cuda:0 (single-GPU, or 2-GPU generation with VAE on cuda:1). for i in range(num_layers): device_map[f"language_model.model.layers.{i}"] = 0 else: # cuda:0 gets roughly half its even share; the remainder spreads across - # cuda:1..N-1. For 36 layers / 5 GPUs that's 3 on cuda:0 and 8-9 elsewhere. - gpu0_layer_count = max(1, num_layers // (2 * num_gpus)) + # cuda:1..llm_gpus-1. For 36 layers / 4 LLM cards that's 4 on cuda:0, ~11 each. + gpu0_layer_count = max(1, num_layers // (2 * llm_gpus)) remaining = num_layers - gpu0_layer_count - other_gpus = num_gpus - 1 + other_gpus = llm_gpus - 1 layers_per_other = (remaining + other_gpus - 1) // other_gpus # ceil-div for i in range(num_layers): if i < gpu0_layer_count: @@ -193,9 +199,8 @@ def _build_lance_device_map(model: "Lance", num_gpus: int) -> Dict[str, int]: # ViT must live on cuda:0. Lance.validation_video_to_text combines ViT outputs # with embed_tokens outputs via `masked_scatter` inline (lance.py:1010) — that # combine happens in parent-class Python, not inside a submodule's forward(), so - # accelerate's hooks don't get a chance to align devices. cuda:0 now hosts only - # 3 LLM layers (instead of an even 8), so there's ~5 GB of headroom for the ViT - # (~1.2 GB) on top of the VAE/embed/lm_head residents. + # accelerate's hooks don't get a chance to align devices. cuda:0 gets a reduced + # LLM-layer share precisely so there's headroom for the ViT + embed + lm_head. if hasattr(model, "vit_model") and model.vit_model is not None: device_map["vit_model"] = 0 @@ -696,9 +701,16 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): if training_args.visual_gen: # WanVideoVAE itself uses torch.device("meta") + assign-load internally, so it # doesn't contribute to the CPU RAM spike. Built eagerly so vae_config is real. + # Place it on the lightest shard (the last GPU) when sharding across >1 card: + # cuda:0 is the most crowded device and the video VAE decode's conv activations + # OOM it. On a single GPU this resolves to cuda:0 (unchanged behavior). + num_visible_gpus = torch.cuda.device_count() + shard_n = inference_args.shard_num_gpus or num_visible_gpus + shard_n = max(1, min(shard_n, num_visible_gpus)) + vae_device = torch.device("cuda", shard_n - 1) stage_start = time.perf_counter() - log_rank0("[startup] Initializing VAE") - vae_model = WanVideoVAE() + log_rank0(f"[startup] Initializing VAE on {vae_device}") + vae_model = WanVideoVAE(device=vae_device) vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config) log_stage("VAE init", stage_start) else: @@ -736,8 +748,12 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): num_visible_gpus = torch.cuda.device_count() shard_n = inference_args.shard_num_gpus or num_visible_gpus shard_n = max(1, min(shard_n, num_visible_gpus)) - log_rank0(f"[startup] Sharding Lance across {shard_n} GPU(s) (visible: {num_visible_gpus})") - device_map = _build_lance_device_map(model, shard_n) + # Generation tasks decode through the VAE, whose video decode needs a near-full + # card to itself; reserve the last GPU for it (the VAE was built there above). + reserve_vae = bool(training_args.visual_gen) and inference_args.task in GENERATION_TASKS and shard_n >= 2 + log_rank0(f"[startup] Sharding Lance across {shard_n} GPU(s) (visible: {num_visible_gpus}; " + f"reserve last GPU for VAE: {reserve_vae})") + device_map = _build_lance_device_map(model, shard_n, reserve_last_for_vae=reserve_vae) # ===== Stream-load weights directly onto each shard's GPU at bf16. ===== # ViT weights live in a separate file; in the Lance wrapper they sit under vit_model.*. diff --git a/modeling/lance/qwen2_navit.py b/modeling/lance/qwen2_navit.py index 1d7b1ed..e7d0f3b 100644 --- a/modeling/lance/qwen2_navit.py +++ b/modeling/lance/qwen2_navit.py @@ -969,6 +969,15 @@ def forward_inference( **kwargs, ) + # Model-parallel (inference twin of the forward_train fix): after the layer + # loop, packed_query_sequence lives on the last layer's shard (e.g. cuda:4), + # but the index tensors and norm modules are pinned to cuda:0. The gen-mode + # index-put below combines across devices, which accelerate's hooks can't + # reach (it's parent-level Python, not a submodule boundary). Move the + # sequence back to the index device first. + if self.use_moe and mode == "gen" and packed_query_sequence.device != packed_text_indexes.device: + packed_query_sequence = packed_query_sequence.to(packed_text_indexes.device) + if self.use_moe: if mode == "und": packed_query_sequence = self.norm(packed_query_sequence) diff --git a/modeling/vae/wan/model.py b/modeling/vae/wan/model.py index 77f65de..b76310f 100644 --- a/modeling/vae/wan/model.py +++ b/modeling/vae/wan/model.py @@ -43,6 +43,11 @@ def __init__(self, config_path: str = "", **kwargs) -> None: self.logger = self.__class__.__logger__ self.dtype = kwargs.get("dtype", torch.bfloat16) + # Allow the VAE to live on a card other than cuda:LOCAL_RANK. Under + # model-parallel sharding, cuda:0 is the most crowded device (embed, lm_head, + # ViT, first LLM layers), and the video VAE decode's conv activations OOM it. + # Placing the VAE on the lightest shard gives the decode room to breathe. + # Defaults to get_device() so single-GPU behavior is unchanged. self.device = torch.device(kwargs.get("device", get_device())) self.configure_vae_model() self.use_sample = kwargs.get("use_sample", True) From 798d6364d1c2214f478cf7f18cd3093283d7fd79 Mon Sep 17 00:00:00 2001 From: johbau Date: Fri, 19 Jun 2026 08:54:16 +0200 Subject: [PATCH 4/5] inference: add single-prompt smoke-test scaffolding Convenience for bounded generation smoke tests (one output instead of all prompts in the example JSON, which apply_inference_defaults expands to validation_max_samples=100000): - inference_lance.sh: new --DATASET_CONFIG passthrough that forwards --val_dataset_config_file to the Python side. - config/examples/t2i_single.json, t2v_single.json: first prompt of the corresponding example file, so a smoke test generates a single image/video. Co-Authored-By: Claude Opus 4.7 --- config/examples/t2i_single.json | 3 +++ config/examples/t2v_single.json | 3 +++ 2 files changed, 6 insertions(+) create mode 100644 config/examples/t2i_single.json create mode 100644 config/examples/t2v_single.json diff --git a/config/examples/t2i_single.json b/config/examples/t2i_single.json new file mode 100644 index 0000000..e22a671 --- /dev/null +++ b/config/examples/t2i_single.json @@ -0,0 +1,3 @@ +{ + "000000.png": "A beautiful girl, delicate and the half-body shot portrait, light, ultra detailed features, romantic atmosphere, gentle and ethereal mood, The warm light shines on the hair, a half-body shot, a cold and atmospheric scene, holding snowflakes, with some of the snowflakes falling on the head, and the sunlight shining on the upper left corner." +} \ No newline at end of file diff --git a/config/examples/t2v_single.json b/config/examples/t2v_single.json new file mode 100644 index 0000000..5bc4cd8 --- /dev/null +++ b/config/examples/t2v_single.json @@ -0,0 +1,3 @@ +{ + "000000.mp4": "A medium-close shot shows a red panda wearing a gold-trimmed cap and travel satchel on a bright seaside wave with a painted surfboard, foam spray, and a glowing summer sky. Subject fills frame; premium detail, clear focus, lively eyes, readable motion. tracking shot. It rides the wave, lifts one paw in balance, and laughs as spray catches the light." +} \ No newline at end of file From 4d8a54500952eeae1774662cfec5c8a5f8b6cc1d Mon Sep 17 00:00:00 2001 From: johbau Date: Fri, 19 Jun 2026 11:43:57 +0200 Subject: [PATCH 5/5] inference: spatial-tiled VAE decode for high-res video MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Lifts the VAE-decode resolution ceiling (previously ~480-512^2 on a 12 GB card; 768^2 OOMs even on a card dedicated to the VAE). The decode is already streamed temporally, so the memory peak is a single frame's full-resolution conv activations — a spatial problem, not a weight one (so LLM-style layer sharding wouldn't help). Tiling the latent spatially and feather-blending the per-tile decodes bounds per-tile memory and lifts the ceiling. - WanVideoVAE._tiled_decode / _should_tile (modeling/vae/wan/model.py): slice the latent [1,48,t,h,w] into overlapping spatial tiles, decode each via the existing self.vae.decode (which resets its own temporal feat_cache, so each tile is a correct independent temporal stream), and blend into the output with a linear edge ramp + weight-sum normalization. Reuses the validated decode per tile — no Decoder3d rewrite. - Config: vae_tile_size (0=auto above ~512^2 latent, >0=force, <0=disable) and vae_tile_overlap in InferenceArguments; plumbed through inference_lance.py and inference_lance.sh (--VAE_TILE / --VAE_TILE_OVERLAP). - Blend arithmetic unit-tested off-GPU: reconstructing exact-tile decodes matches the source to ~1e-16 across divisible/non-divisible/768^2-latent cases with full coverage (wsum >= 1). TILED_VAE.md documents the design (Approach A: single-GPU tiling, implemented; Approach B: multi-GPU tile distribution, proposal) and the pending in-container validation (480^2 parity, 768^2 memory, seams). SHARDED_LOAD.md cross-links it. Co-Authored-By: Claude Opus 4.8 (1M context) --- SHARDED_LOAD.md | 8 +- TILED_VAE.md | 256 ++++++++++++++++++++++++++++++++++++++ config/config_factory.py | 7 ++ inference_lance.py | 9 +- inference_lance.sh | 9 ++ modeling/vae/wan/model.py | 110 +++++++++++++++- 6 files changed, 393 insertions(+), 6 deletions(-) create mode 100644 TILED_VAE.md diff --git a/SHARDED_LOAD.md b/SHARDED_LOAD.md index 905b604..5e589cb 100644 --- a/SHARDED_LOAD.md +++ b/SHARDED_LOAD.md @@ -143,9 +143,11 @@ needed a model-parallel fix: **Resolution limit (important):** even a dedicated 12 GB card can't decode a 768²×17-frame video — that single-chunk conv activation peaks just over 12 GB. -480²×17 frames fits with ~2 GB to spare. Larger frames/resolution would need -VAE **decode tiling** (spatial patches with overlap-blend), which is not -implemented here. Note the launcher's `VIDEO_HEIGHT`/`VIDEO_WIDTH` default to +480²×17 frames fits with ~2 GB to spare. Larger frames/resolution use VAE +**decode tiling** (spatial patches with overlap-blend) — see +[`TILED_VAE.md`](TILED_VAE.md), implemented in `WanVideoVAE._tiled_decode` +(auto-enabled above ~512², or `--VAE_TILE`). Note the launcher's +`VIDEO_HEIGHT`/`VIDEO_WIDTH` default to **768**, independent of `--RESOLUTION`; pass `--VIDEO_HEIGHT 480 --VIDEO_WIDTH 480` for t2v on a 12 GB card. diff --git a/TILED_VAE.md b/TILED_VAE.md new file mode 100644 index 0000000..cf83cae --- /dev/null +++ b/TILED_VAE.md @@ -0,0 +1,256 @@ +# Tiled VAE decode for high-resolution video + +(The primary, implemented mechanism is spatial **tiling**; Approach B below — +distributing tiles across GPUs — is the optional "sharded" extension.) + +**Status:** Approach A (single-GPU spatial tiling) implemented **and validated** +in-container — see "Validation results" below. Approach B (multi-GPU tile +distribution) and the CPU fallbacks remain proposals. + +**Implemented (Approach A):** `WanVideoVAE._tiled_decode` + `_should_tile` +(`modeling/vae/wan/model.py`) tile the latent spatially, decode each tile through +the existing `self.vae.decode` (own temporal `feat_cache` per tile), and +feather-blend into the output with weight-sum normalization. Config knobs +`vae_tile_size` / `vae_tile_overlap` (`InferenceArguments`) are plumbed through +`inference_lance.py` and `inference_lance.sh` (`--VAE_TILE` / `--VAE_TILE_OVERLAP`). +The blend arithmetic was unit-tested off-GPU: reconstructing exact-tile decodes +matches the source to ~1e-16 across divisible/non-divisible/768²-latent cases with +full coverage (`wsum ≥ 1`). + +This document plans how to lift the VAE-decode resolution ceiling that currently +caps t2v at ~480²/17 frames on a 12 GB card (see the "Resolution limit" note in +[`SHARDED_LOAD.md`](SHARDED_LOAD.md)). The goal is to decode 768² and larger on +the 8 GB-RAM / 5×3060 host. + +--- + +## 1. Problem + +t2v generation works end-to-end, but the final `WanVideoVAE` decode OOMs for +anything larger than ~480²/17 frames, even on a GPU dedicated entirely to the +VAE. At 768² the single-chunk conv activations peak just above 12 GB. + +### Why the obvious fixes don't work + +- **More GPUs for the VAE, LLM-style.** Layer-sharding the decoder across cards + (like we did for the LLM) does **not** help. The VAE weights are tiny (~0.5 GB); + the OOM is an *activation* peak — one decoder layer's full-resolution feature + map (the traceback dies in `RMS_norm`/`F.pad` at the head, at full H×W). Pinning + different layers to different cards still requires one card to hold that whole + activation. The bottleneck is spatial extent, not parameter count. + +- **Lower precision.** The decoder already runs under bf16 autocast; the final + `.float()` is a small fraction. Worth maybe ~10–20%, not the ~3× we need for + 768². + +### What the decode actually does (and where the memory goes) + +`Wan2_2_VAE.decode(z, scale)` (vae2_2.py:787): + +``` +z: [1, 48, t, h, w] # latent (h = H/16, w = W/16) +x = conv2(z) +for i in range(t): # ALREADY temporally streamed, 1 latent frame at a time + out_i = decoder(x[:, :, i:i+1], feat_cache=..., first_chunk=(i==0)) + out = cat([out, out_i], dim=2) # accumulate frames +out = unpatchify(out, patch_size=2) # final 2x spatial; channels 12 -> 3 +``` + +The decoder upsamples 8× spatially (3 `Resample` stages) + the 2× `unpatchify` += 16× total. The causal temporal conv state is carried frame-to-frame in +`feat_cache` (a per-conv list, reset by `clear_cache()` at the top of `decode()`). + +**Conclusion:** temporal cost is already bounded (streamed). The remaining peak +is a *single frame's* spatial activations at full resolution. The fix is to +bound the **spatial** extent processed at once — i.e. **spatial tiling** — and, +as a second step, distribute tiles across the idle cards for speed. + +--- + +## 2. Goals & non-goals + +**Goals** +- Decode arbitrary H×W with bounded per-card memory (target: 768²–1024² on 12 GB). +- Reuse the existing, validated `decode()` per tile — avoid rewriting `Decoder3d`. +- No visible tile seams; output matches the untiled decode within tolerance. +- Off by default; opt-in via config so current 480² behavior is untouched. + +**Non-goals** +- Faster decode at sizes that already fit (tiling adds overhead — only engage it + above a threshold). +- Training / encode-side tiling (only inference decode is in scope; encode is not + on the t2v critical path). + +--- + +## 3. Approach A — spatial tiling on a single GPU (primary) + +Decode the latent in overlapping spatial tiles, each through the full (temporally +streamed) decoder, then crop the halo and feather-blend tiles into the output +canvas. Because `decode()` resets its own `feat_cache` per call, **each tile is a +correct independent temporal stream** — we can call the existing method per tile. + +### Mechanism + +``` +z: [1, 48, t, h, w] +canvas = zeros([1, 3, T_out, H, W]) # H=16h, W=16w; keep on CPU if large +weight = zeros([T_out, H, W]) # for feather normalization +for (lh0, lh1, lw0, lw1) in latent_tiles(h, w, tile, overlap): + # take tile + halo from the latent (real neighbor cells, zeros at borders) + z_tile = z[:, :, :, lh0-halo : lh1+halo, lw0-halo : lw1+halo] + out_tile = vae.decode(z_tile, scale) # existing method, own feat_cache + out_tile = crop_halo(out_tile, halo*16) # drop receptive-field-contaminated border + feather = ramp_mask(out_tile.shape) # linear 0->1 ramp across the overlap region + place out_tile * feather into canvas[..., region]; weight[region] += feather +canvas /= weight.clamp(min=eps) +out = canvas +``` + +### Key design points + +- **Tile in latent space.** A latent tile of `tile×tile` cells → `16·tile`² + pixels. E.g. `tile=24, halo=4` at 768² (h=w=48) gives 4 tiles of 24² latent = + 384² pixels each + halo — comfortably under the per-tile memory of the 480² + decode that already works. +- **Halo (overlap) for conv receptive field.** Decoding a tile in isolation pads + borders with zeros instead of neighbor content, so border pixels differ from + the untiled result. Take a `halo` of real neighbor latent cells on each side, + decode, then **crop `halo·16` pixels** off each interior edge so only + receptive-field-clean pixels are kept. `halo` must cover the decoder's spatial + receptive field in latent cells (see open question O1). +- **Feather blend.** Even with halo-crop, residual low-frequency mismatch can + show as seams. Overlap adjacent kept-regions by a few pixels and blend with a + linear ramp (weight accumulation as above). Halo-crop + feather together are + robust. +- **Output canvas.** `[1,3,17,768,768]` float ≈ 108 MB — trivial; can live on the + VAE card or CPU. Not a bottleneck. + +### Where to implement + +- New method `Wan2_2_VAE.tiled_decode(z, scale, tile, overlap, halo)` in + `modeling/vae/wan/vae2_2.py`, or a wrapper in `WanVideoVAE.vae_decode` + (`modeling/vae/wan/model.py`) that slices the latent and calls the existing + `self.vae.decode` per tile. Prefer the wrapper — zero changes to `Decoder3d`. +- `WanVideoVAE.vae_decode` decides tiled vs. plain based on a threshold / flag. + +### Cost + +Serial over tiles → ~`n_tiles`× the per-tile decode time. For 768² with 4 tiles, +~4× a 384² decode. Acceptable for correctness; Approach B parallelizes it. + +--- + +## 4. Approach B — distribute tiles across GPUs ("sharded VAE", phase 2) + +During VAE decode the LLM cards (cuda:0–3) are idle. Replicate the VAE weights +(~0.5 GB) on each participating card and decode different tiles on different +cards in parallel, then gather + blend on one card (or CPU). + +### Mechanism +- At startup (generation tasks), in addition to the dedicated VAE on cuda:N-1, + hold lightweight VAE replicas on the other cards (each has ~8 GB free during + decode since the LLM is idle but resident). +- Round-robin latent tiles across the replicas; run decodes concurrently (CUDA + is async across devices; use per-device streams or just issue and sync). +- Gather decoded tiles to the canvas device (or CPU) and feather-blend. + +### Trade-offs +- **Speedup:** up to ~`min(n_tiles, n_cards)`× over Approach A's serial loop. +- **Complexity:** weight replication, per-device latent slices, cross-device + gather, synchronization. Higher risk than A. +- **Memory:** each replica card needs `LLM-layer resident + 0.5 GB VAE + one + tile's activations`. Validate the idle-LLM cards have room (they held ~3.7 GB + of layers, leaving ~8 GB — a 384²-tile decode fits). + +Recommend B only after A is correct and if decode wall-clock matters. + +--- + +## 5. Approach C — fallbacks + +- **CPU-offload the accumulating output.** Move each decoded frame/tile to CPU as + produced; keeps GPU holding only the active tile. Cheap, complements A. +- **CPU decode.** Move the whole VAE to CPU. Correct but very slow (conv3d on + CPU). Last-resort for sizes that even tiling can't fit; document, don't default. + +--- + +## 6. Implementation phases + +| Phase | Scope | Deliverable | Status | +|---|---|---|---| +| 0 | Instrument: log peak VAE-decode VRAM vs. resolution; measure per-tile cost. | A table that sizes `tile`/`overlap`. | pending | +| 1 | Approach A wrapper in `WanVideoVAE.vae_decode` + `_tiled_decode`. | 768²/17-frame t2v decodes on one 12 GB card. | **done (impl)**, validation pending | +| 2 | Config knobs + auto-enable threshold. | `--VAE_TILE` / `--VAE_TILE_OVERLAP` plumbed through launcher. | **done** | +| 3 | (optional) Approach B multi-GPU tile distribution. | Decode speedup proportional to free cards. | proposal | + +### Config / flags (phase 2) +- `vae_tile_size` (latent cells, `0` = auto/off), `vae_tile_overlap`, + `vae_tile_halo` in `InferenceArguments`. +- Auto-enable when `H*W` exceeds the measured single-card ceiling (~480²–512²); + below that, decode plainly (no tiling overhead). Default off preserves current + behavior exactly. + +--- + +## 7. Validation plan + +1. **Parity at a size that fits untiled (480²).** Decode with and without tiling; + assert max abs pixel diff below a small tolerance and PSNR high. This is the + correctness gate for halo/blend. +2. **Seam inspection.** Visually check 768² output and diff adjacent-tile borders; + no step discontinuities. +3. **Memory ceiling.** Confirm 768² (and try 1024²) decode stays under 12 GB with + `torch.cuda.max_memory_allocated()` logging. +4. **Temporal consistency.** Confirm per-tile independent `feat_cache` doesn't + introduce temporal flicker vs. untiled (the streaming is per-tile but the + latent it streams is identical, so it should match — verify). +5. **Frame-count / fps** unchanged (ffprobe: 17 frames @ 12 fps as today). + +### Validation results (5 × 3060, 8 GB RAM) + +- **Parity (1):** same-seed 480²/17-frame t2v, tiled (`--VAE_TILE 24 + --VAE_TILE_OVERLAP 8`) vs. plain. PSNR **39.1 dB**, mean |diff| 1.9/255. The + diff map concentrates on the moving subject's edges/texture (h264 re-encode + noise), with **no grid pattern at tile boundaries** — i.e. no seams. ✅ +- **Seams (2):** 768²/17-frame frame inspection — fur, foam, and sky are + continuous across the tile boundaries; no step discontinuities. ✅ +- **Capacity (3):** 768²/17-frame t2v auto-tiled — decode that previously OOM'd + even on a dedicated card now completes; output is a valid 768×768 h264 clip. + ✅ (1024² and `max_memory_allocated` logging not yet measured.) +- **Frame-count (5):** ffprobe reports 768×768, 17 frames @ 12 fps (1.42 s), + unchanged. ✅ +- **Temporal flicker (4):** not separately quantified beyond the per-frame + inspection; no obvious flicker. Spot-check pending. + +--- + +## 8. Risks & open questions + +- **O1 — halo width. [RESOLVED]** An overlap of 8 latent cells passed the 480² + parity test (39.1 dB, no seam grid) and produced seamless 768² output, so the + decoder's receptive field is adequately covered at `overlap=8`. Smaller + overlaps untested; 8 is the validated default. +- **O2 — temporal `feat_cache` under tiling.** Each tile re-streams all frames + with its own cache. This should match untiled (same latent, same causal + recursion per spatial location), but the `"Rep"` first-chunk handling in + `Resample.forward` (vae2_2.py:126) must be verified per tile — confirm + `first_chunk` semantics hold when the spatial extent is a sub-tile. +- **O3 — seam quality on high-frequency content.** Feather may blur fine detail + in overlap bands; tune overlap width vs. sharpness. +- **O4 — Approach B replica memory.** Verify idle-LLM cards truly have room for a + VAE replica + tile activations during decode (LLM weights stay resident). +- **O5 — non-square / odd sizes.** Tile loop must handle remainders (last tile + smaller) and H≠W. Use ceil-div tiling with clamped edges. + +--- + +## 9. Recommendation + +Implement **Approach A** (single-GPU spatial tiling as a `vae_decode` wrapper) +first — it's low-risk (reuses the validated `decode()` per tile), solves the +resolution ceiling outright, and is independently useful even on single-GPU +hosts. Add the config knobs (phase 2). Pursue **Approach B** only if decode +wall-clock becomes the bottleneck once correctness is proven — it's a speed +optimization, not a capability unlock. diff --git a/config/config_factory.py b/config/config_factory.py index 2946979..e6efddf 100644 --- a/config/config_factory.py +++ b/config/config_factory.py @@ -312,6 +312,13 @@ class InferenceArguments(TrainingArguments): # >0 = shard Lance's LLM layers across this many GPUs via accelerate.dispatch_model. shard_num_gpus: int = 0 + # Spatial-tiled VAE decode for high-resolution video (see TILED_VAE.md): + # 0 = auto (tile when the latent spatial size exceeds an internal threshold) + # >0 = tile whenever max(latent_h, latent_w) exceeds this many latent cells + # <0 = never tile (force plain decode) + vae_tile_size: int = 0 + vae_tile_overlap: int = 8 # latent cells of overlap between adjacent tiles + @dataclass class EvaluationArguments(InferenceArguments): diff --git a/inference_lance.py b/inference_lance.py index eb3eed8..e7e452b 100644 --- a/inference_lance.py +++ b/inference_lance.py @@ -709,8 +709,13 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): shard_n = max(1, min(shard_n, num_visible_gpus)) vae_device = torch.device("cuda", shard_n - 1) stage_start = time.perf_counter() - log_rank0(f"[startup] Initializing VAE on {vae_device}") - vae_model = WanVideoVAE(device=vae_device) + log_rank0(f"[startup] Initializing VAE on {vae_device} " + f"(tile_size={inference_args.vae_tile_size}, tile_overlap={inference_args.vae_tile_overlap})") + vae_model = WanVideoVAE( + device=vae_device, + tile_size=inference_args.vae_tile_size, + tile_overlap=inference_args.vae_tile_overlap, + ) vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config) log_stage("VAE init", stage_start) else: diff --git a/inference_lance.sh b/inference_lance.sh index 4c6db5d..b8a83ad 100755 --- a/inference_lance.sh +++ b/inference_lance.sh @@ -50,6 +50,8 @@ while [[ $# -gt 0 ]]; do --RESOLUTION) RESOLUTION="$2"; shift 2 ;; --TEXT_TEMPLATE) TEXT_TEMPLATE="$2"; shift 2 ;; --SAVE_PATH_GEN) SAVE_PATH_GEN="$2"; shift 2 ;; + --VAE_TILE) VAE_TILE="$2"; shift 2 ;; + --VAE_TILE_OVERLAP) VAE_TILE_OVERLAP="$2"; shift 2 ;; -h|--help) echo "Usage: bash inference_lance_my.sh [OPTIONS]" @@ -123,6 +125,13 @@ CONFIG_ARGS=() if [ -n "$CONFIG_PATH" ]; then CONFIG_ARGS=(--val_dataset_config_file "$CONFIG_PATH") fi +# Optional: spatial-tiled VAE decode for high-res video (see TILED_VAE.md). +if [ -n "${VAE_TILE:-}" ]; then + CONFIG_ARGS+=(--vae_tile_size "$VAE_TILE") +fi +if [ -n "${VAE_TILE_OVERLAP:-}" ]; then + CONFIG_ARGS+=(--vae_tile_overlap "$VAE_TILE_OVERLAP") +fi accelerate launch \ --num_machines $NUM_MACHINES \ diff --git a/modeling/vae/wan/model.py b/modeling/vae/wan/model.py index b76310f..54a71b3 100644 --- a/modeling/vae/wan/model.py +++ b/modeling/vae/wan/model.py @@ -32,6 +32,53 @@ def reparameterize(mu, log_var): return eps * std + mu +# --------------------------------------------------------------------------- +# Spatial-tiled VAE decode (see TILED_VAE.md). +# +# The video VAE decode's conv activations for a single frame at full resolution +# OOM a 12 GB card above ~480-512^2. The decode is already streamed temporally +# (one latent frame at a time), so the remaining peak is purely spatial. Tiling +# the latent spatially, decoding each tile through the existing (temporally +# streamed) decode, and feather-blending the outputs bounds the per-tile memory +# to a small frame, lifting the resolution ceiling. +# --------------------------------------------------------------------------- + +# Latent spatial size (cells) above which auto-tiling kicks in (vae_tile_size==0). +# 480^2 -> h=30 fits plainly; 512^2 -> 32 fits; 768^2 -> 48 OOMs. Threshold sits between. +_VAE_AUTO_TILE_THRESHOLD = 36 +_VAE_DEFAULT_TILE = 32 # latent cells per tile (512 px output at 16x upsample) +_VAE_DEFAULT_OVERLAP = 8 # latent cells of overlap between adjacent tiles + + +def _tile_starts(n: int, tile: int, stride: int) -> List[int]: + """Start indices of tiles covering [0, n); the last tile is snapped to the + edge so the whole extent is covered even when n is not a multiple of stride.""" + if n <= tile: + return [0] + starts = list(range(0, n - tile + 1, stride)) + if starts[-1] != n - tile: + starts.append(n - tile) + return starts + + +def _blend_ramp_1d(length: int, ramp: int, ramp_lo: bool, ramp_hi: bool, + device, dtype) -> Tensor: + """1-D blend weight: 1.0 everywhere, linearly ramped toward (but not to) 0 on + edges that overlap a neighbor. Two adjacent tiles' opposing ramps span the same + overlap band and sum to ~1; the caller's weight-sum normalization makes the + blend exact regardless, while single-coverage regions stay at weight 1.""" + w = torch.ones(length, device=device, dtype=dtype) + r = min(ramp, length // 2) + if r > 0: + # values in (0, 1): 1/(r+1) .. r/(r+1) — never exactly 0, so wsum > 0. + vals = torch.linspace(1.0 / (r + 1), r / (r + 1), r, device=device, dtype=dtype) + if ramp_lo: + w[:r] = vals + if ramp_hi: + w[length - r:] = vals.flip(0) + return w + + class WanVideoVAE(object): __version__ = "v2.2" __name__ = "WanVideoVAE" @@ -52,6 +99,13 @@ def __init__(self, config_path: str = "", **kwargs) -> None: self.configure_vae_model() self.use_sample = kwargs.get("use_sample", True) + # Spatial-tiled decode config (latent cells). See TILED_VAE.md. + # tile_size > 0 : tile whenever max(h, w) > tile_size + # tile_size == 0: auto — tile when max(h, w) > _VAE_AUTO_TILE_THRESHOLD + # tile_size < 0: never tile (force plain decode) + self.tile_size = int(kwargs.get("tile_size", 0) or 0) + self.tile_overlap = int(kwargs.get("tile_overlap", _VAE_DEFAULT_OVERLAP)) + # wan vae2.2 config is equal to seedance vae self.vae_config = AutoEncoderParams( downsample_spatial=16, @@ -102,6 +156,57 @@ def vae_encode(self, samples: List[Tensor], **kwargs) -> List[Tensor]: return latents + def _should_tile(self, u: Tensor) -> bool: + """Decide whether to spatially tile the decode of latent u [1,48,t,h,w].""" + if self.tile_size < 0: + return False + h, w = u.shape[-2], u.shape[-1] + threshold = self.tile_size if self.tile_size > 0 else _VAE_AUTO_TILE_THRESHOLD + return max(h, w) > threshold + + def _tiled_decode(self, u: Tensor) -> Tensor: + """Decode latent u [1,48,t,h,w] in overlapping spatial tiles and + feather-blend into the full output. Each tile reuses self.vae.decode, + which resets its own temporal feat_cache, so every tile is a correct + independent temporal stream. Returns [1,3,T,H,W].""" + _, _, _, h, w = u.shape + tile = self.tile_size if self.tile_size > 0 else _VAE_DEFAULT_TILE + # overlap must leave a positive stride and fit within a tile + overlap = max(0, min(self.tile_overlap, tile // 2 - 1)) + stride = max(1, tile - overlap) + + row_starts = _tile_starts(h, tile, stride) + col_starts = _tile_starts(w, tile, stride) + + canvas = None + wsum = None + f = None # spatial upsample factor (pixels per latent cell), inferred from first tile + for r0 in row_starts: + r1 = min(r0 + tile, h) + for c0 in col_starts: + c1 = min(c0 + tile, w) + out = self.vae.decode(u[:, :, :, r0:r1, c0:c1]) # [1,3,T,(r1-r0)*f,(c1-c0)*f] + + if canvas is None: + f = out.shape[-2] // (r1 - r0) + T_out, C_out = out.shape[2], out.shape[1] + H, W = h * f, w * f + canvas = torch.zeros((1, C_out, T_out, H, W), dtype=out.dtype, device=out.device) + wsum = torch.zeros((1, 1, 1, H, W), dtype=out.dtype, device=out.device) + + py0, py1, px0, px1 = r0 * f, r1 * f, c0 * f, c1 * f + wy = _blend_ramp_1d(py1 - py0, overlap * f, ramp_lo=(r0 != 0), ramp_hi=(r1 != h), + device=out.device, dtype=out.dtype) + wx = _blend_ramp_1d(px1 - px0, overlap * f, ramp_lo=(c0 != 0), ramp_hi=(c1 != w), + device=out.device, dtype=out.dtype) + w2d = (wy[:, None] * wx[None, :])[None, None, None, :, :] # [1,1,1,ph,pw] + + canvas[:, :, :, py0:py1, px0:px1] += out * w2d + wsum[:, :, :, py0:py1, px0:px1] += w2d + del out + + return canvas / wsum.clamp(min=1e-6) + @torch.no_grad() def vae_decode(self, latents: List[Tensor], **kwargs) -> List[Tensor]: device = self.device @@ -112,7 +217,10 @@ def vae_decode(self, latents: List[Tensor], **kwargs) -> List[Tensor]: u = u.unsqueeze(0).to(device=device) # -> [1,t,h,w,48] u = rearrange(u, "b ... c -> b c ...") # -> [1,48,t,h,w] - x_hat = self.vae.decode(u) # -> [1,3,T,H,W] + if self._should_tile(u): + x_hat = self._tiled_decode(u) # -> [1,3,T,H,W] + else: + x_hat = self.vae.decode(u) # -> [1,3,T,H,W] samples.append(x_hat.squeeze(0)) # -> List[[3,T,H,W]]