diff --git a/LOW_RAM_LOAD.md b/LOW_RAM_LOAD.md new file mode 100644 index 0000000..d55154b --- /dev/null +++ b/LOW_RAM_LOAD.md @@ -0,0 +1,108 @@ +# Low-RAM inference load + +This change lets `inference_lance.py` run on hosts with **less system RAM than +the model needs to materialize in fp32 on CPU** (~12 GB for Lance_3B). It +removes the CPU-side memory spike from the load path. Multi-GPU model +parallelism is layered on top by a separate change — see +[`SHARDED_LOAD.md`](SHARDED_LOAD.md). + +## Why + +On `main`, the first stage of `main()`: + +```python +language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) +``` + +allocates a freshly-init'd fp32 3B model on CPU (~12 GB). On an 8 GB host this +gets OOM-killed before any GPU code runs. The actual checkpoint load +(`load_file → load_state_dict`) makes things worse by holding the full state +dict on CPU as a second copy. Several smaller allocations downstream +(numpy fp64 sin-cos, full-file `safe_open()` mmap) also push past the ceiling. + +## What changed + +### 1. Meta-init the model skeleton + +LLM, ViT, and the `Lance` wrapper are constructed inside +`accelerate.init_empty_weights()`. Every `nn.Parameter` becomes shape-only on +the `meta` device — zero storage. The fp32-on-CPU spike disappears. + +`modeling/lance/modeling_utils.py`'s `PositionEmbedding{,3D}._init_weights` +now early-returns when its param is still meta, deferring sin-cos +materialization until after the load. + +### 2. Stream the checkpoint, don't mmap it + +`safetensors.safe_open()` mmaps the whole 12 GB checkpoint file. On a host +with strict commit accounting and no swap, the kernel refuses a 12 GB +file-backed mapping (`ENOMEM`). The streaming loader (`_stream_load_into`) +opens the file in plain binary mode, reads the 8-byte header length + JSON +header, and seeks to each tensor's data offset. Peak CPU RAM during load is +one tensor at a time — worst case ~1.2 GB for the embedding layer, briefly. + +### 3. 
Load tensors directly to GPU at bf16 + +Each tensor is read into CPU, cast to bf16, and handed to +`accelerate.utils.set_module_tensor_to_device(model, name, device, +value=tensor, dtype=torch.bfloat16)`. **The `dtype=` argument is +load-bearing**: without it, accelerate silently casts the value to +`old_value.dtype` to match the meta tensor's nominal dtype (fp32 default). +That would both double VRAM and produce fp32 weights that the bf16 autocast +path then promotes back to fp32 mid-attention — eventually crashing on an +index-put dtype mismatch. + +After the load loop, `_materialize_remaining_meta` walks the model for +parameters still on meta (e.g. `latent_pos_embed.pos_embed`, which the +original code popped from the checkpoint to recompute per-resolution), +allocates real storage on the target device, and re-runs `_init_weights()`. + +### 4. Compute sin-cos position embeddings on GPU, not CPU + +`get_3d_sincos_pos_embed` (numpy fp64) used to allocate three intermediate +arrays of shape `(t*h*w, ~D/3)` plus a concatenated copy — peaking around +**4 GB of CPU RAM** for Lance's defaults (`t=31, h=w=64, D=2048`). + +Replaced with `_torch_3d_sincos` / `_torch_2d_sincos` that compute on the +parameter's device in torch fp32. CPU contribution is ~zero. Same change for +the 2D variant used by `PositionEmbedding`. + +### 5. `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` + +The per-tensor streaming pattern fragments the default CUDA caching +allocator enough that big tensors can fail to allocate even when total free +VRAM is plenty. `expandable_segments` coalesces freed regions and lets large +allocations grow into them. Set by default in +`benchmarks/sample_env.sh::lance_setup_common_env` (`${VAR:-default}` so a +user-set value still wins). 
+ +## Memory profile (Lance_3B on a single GPU) + +| Stage | Peak CPU RSS | Notes | +|---|---|---| +| Meta-init LLM/ViT/Lance | ~few hundred MB | torch + python + dataclass overhead | +| ViT streaming load (1.2 GB safetensors) | ~1 GB | one fp32 tensor at a time | +| Lance streaming load (12.3 GB safetensors) | ~1.5 GB | embedding layer is the worst tensor | +| Materialize popped sin-cos pos_embed | tiny | computed on GPU | +| Tokenizer + resize | <500 MB | | + +Peak CPU RSS during load stays under ~2 GB, comfortably below an 8 GB +ceiling. VRAM usage on the target card is ~6 GB for Lance_3B in bf16, plus +ViT (~1.2 GB) and VAE (~2–3 GB), which fits on a single 40 GB GPU but, once activations are added, not a 12 GB one — for +that case, see [`SHARDED_LOAD.md`](SHARDED_LOAD.md). + +## What `main` users keep + +For hosts that *do* have enough RAM, this change is still net-positive: +the load is faster (no fp32 → bf16 conversion afterwards, no full state-dict +held on CPU) and uses half the VRAM (bf16 instead of fp32 at rest). The +launcher and config are unchanged in this commit; the behavior change is +transparent to the runner. + +## File-by-file summary + +| File | Change | +|---|---| +| `inference_lance.py` | `init_empty_weights()` for LLM/ViT/Lance; new streaming safetensors reader (`_read_safetensors_header`, `_read_safetensors_tensor`, `_stream_load_into`); `_materialize_remaining_meta`; `_resolve_lance_checkpoint`; passes `dtype=torch.bfloat16` to `set_module_tensor_to_device`; removed the per-batch `.to(device)` calls on the model. | +| `modeling/lance/modeling_utils.py` | New `_torch_2d_sincos` / `_torch_3d_sincos`; `_init_weights` early-returns on meta tensors and otherwise computes on the param's device. | +| `benchmarks/sample_env.sh` | Exports `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (user override respected). 
| diff --git a/SHARDED_LOAD.md b/SHARDED_LOAD.md new file mode 100644 index 0000000..5e589cb --- /dev/null +++ b/SHARDED_LOAD.md @@ -0,0 +1,205 @@ +# Sharded model-parallel inference + +This change adds **single-process, model-parallel inference across N GPUs**. +It's the second half of the work that lets Lance run on a host with 8 GB of +system RAM and 5 × RTX 3060 GPUs. The first half — getting the model loaded without +materializing fp32 weights on CPU — is in +[`LOW_RAM_LOAD.md`](LOW_RAM_LOAD.md) and must be in place first; this change +builds on the streaming loader and `init_empty_weights()` infrastructure +introduced there. + +## Why + +A 3B Lance model in bf16 is ~6 GB of weights, plus ViT (~1.2 GB) and the +WanVideoVAE (~2-3 GB) for generation tasks. Taken together, that does not fit on a single +12 GB 3060 with room for activations. But it does fit comfortably across +the **60 GB aggregate VRAM** of five 3060s if the LLM's transformer layers +are sharded across cards. Doing that requires: + +- A device map that puts the right layers on the right cards. +- Cross-card forward hooks (`accelerate.dispatch_model`). +- Source-side fixes everywhere Lance's existing code assumes everything + lives on one device. + +The previous launcher ran `accelerate launch --num_processes=$NUM_GPUS`, +which is **data-parallel**: each process gets its own full copy of the +model. That doesn't help here — each rank still needs to fit the whole +model, and CPU RAM pressure goes up N× (one copy materialized per process). +For inference we want model-parallel: one process, model split across cards. + +## What changed + +### 1. Device map + +`_build_lance_device_map(model, num_gpus)` (in `inference_lance.py`) builds +a `{module_name: gpu_index}` map: + +- LLM transformer layers split across `cuda:1..N-1`, with `cuda:0` getting + a **reduced** share (about half of the even split) because cuda:0 also + hosts the entry/exit modules and the WanVideoVAE. 
+- `embed_tokens`, `norm`, `norm_moe_gen` (MoT generation-branch sibling), + `rotary_emb`, `lm_head` pinned to `cuda:0` — these are the token-flow + boundaries. +- ViT pinned to `cuda:0` because `Lance.validation_video_to_text` combines + ViT output with `embed_tokens` output via `masked_scatter` *inline* + (lance.py around line 1010). That combine happens in parent-class Python, + not inside a submodule's `forward()`, so accelerate's hooks don't get a + chance to align devices. +- Connector / time_embedder / vae2llm / llm2vae / latent_pos_embed all on + cuda:0 (small). +- Safety net: any parameter not covered by an explicit prefix lands on + cuda:0. Without this, `dispatch_model` rejects the device_map with a + hard error the first time someone adds a top-level MoT sibling we didn't + anticipate. + +### 2. `accelerate.dispatch_model` + +Installs pre/post forward hooks on each dispatched submodule so activations +get moved to the right card before each `.forward()`. After this point the +model must **not** be `.to()`-d (that would collapse every shard onto one +card). The per-batch `fsdp_model.to(device, dtype=bf16)` call in +`validate_on_fixed_batch` was already removed in the low-RAM change. + +The streaming loader from the low-RAM change already supports a non-empty +device_map — every tensor is routed onto the GPU dictated by +`_device_for_param(name, device_map)` at load time, so the model is on +its shards *before* hooks are attached. + +### 3. Replace `flex_attention` with eager-SDPA dense masks + +`flex_attention`'s `BlockMask` captures device-specific tensors when it's +built. Under model parallelism, a layer on `cuda:>0` calls `flex_attention` +with `q/k/v` on that shard and a mask whose captures live on `cuda:0`; +dynamo's tracer refuses to combine them with +`Unhandled FakeTensor Device Propagation`. 
+ +The fix uses a path that already exists in `qwen2_navit.py`: when +`attention_mask` is a `List`, the attention forward iterates per-sample and +runs `torch.nn.functional.scaled_dot_product_attention` instead of +`flex_attention`. SDPA has no dynamo trace and crosses devices cleanly via +the standard accelerate hooks. + +`_flex_mask_to_dense_list(mask_fn, seqlen, device, dtype)` evaluates the +flex mask function on a meshgrid of `(q_idx, kv_idx)` to get a bool mask, +converts to additive float (`-inf` where masked), and returns it as a +single-element `List`. All three `create_block_mask` call sites in +`lance.py` (one in `process_attention_mask`, one in the main `forward`, +one in `validation_video_to_text`) route through this helper. + +### 4. Parent-class device-alignment fixes + +A few places in `lance.py` and `qwen2_navit.py` combine tensors from +different shards in parent-Python (not inside a submodule's `forward()`), +which accelerate's hooks cannot reach. Each was fixed locally: + +- `qwen2_navit.py:619` — at the start of each layer's `forward_train`, + `attention_mask.to(device=packed_sequence_.device)` now handles both + the old single-Tensor BlockMask path and the new List-of-Tensors SDPA + path. +- `qwen2_navit.py:901` — after the layer loop in `Qwen2Model.forward_train`, + `packed_sequence` lives on whichever shard ran the last layer (e.g. + `cuda:N-1`). The index tensors and the final `norm` live on `cuda:0`. + Added one `packed_sequence.to(packed_und_token_indexes.device)` to + consolidate before the indexing-based combine. + +### 5. Launcher and config + +- `inference_lance.sh`: + - `NUM_GPUS=5` default (was 1). It now means "number of shards", not + "number of data-parallel processes". + - `accelerate launch --num_processes 1` always — model parallelism is + inside one process. + - Forwards `--shard_num_gpus $NUM_GPUS` to the Python side. +- `config/config_factory.py`: adds `shard_num_gpus: int = 0` to + `InferenceArguments`. 
`0` means "use all visible GPUs" + (`torch.cuda.device_count()`); >0 caps to that many. + +### 6. Generation tasks (t2i / t2v): the diffusion + VAE-decode path + +The generation tasks exercise code the understanding path doesn't, and each +needed a model-parallel fix: + +- **`forward_inference` gen-mode norm** (`qwen2_navit.py`, the `mode=="gen"` + branch after the decoder-layer loop). This is the inference-mode twin of the + `forward_train` fix above: after the layer loop `packed_query_sequence` is on + the last shard, but `packed_text_indexes`/`packed_vae_token_indexes` and the + `norm`/`norm_moe_gen` modules are on cuda:0. Added a + `.to(packed_text_indexes.device)` guard before the index-put. This runs on + every diffusion timestep. (The KVcache generation path attends via + `flash_attn_varlen_func`, *not* flex_attention, so the dense-mask change in §3 + isn't even exercised here — no regression risk.) + +- **VAE placement** (`modeling/vae/wan/model.py` + `inference_lance.py`). The + WanVideoVAE used to hard-code `get_device()` = cuda:0. `WanVideoVAE` now + accepts a `device=` (default `get_device()`, so single-GPU is unchanged), and + the launcher builds it on the **last** shard. + +- **Dedicated VAE card** (`_build_lance_device_map(..., reserve_last_for_vae=True)` + for generation tasks). The video VAE decode's conv activations are large + enough (~9 GB at 480² / 17 frames) that they won't fit on a card also holding + LLM layers. For generation, the LLM is sharded across the first `N-1` cards + and the last card is left empty of LLM weights so the VAE decode has a + near-full 12 GB to itself. + +**Resolution limit (important):** even a dedicated 12 GB card can't decode a +768²×17-frame video — that single-chunk conv activation peaks just over 12 GB. +480²×17 frames fits with ~2 GB to spare. 
Larger frames/resolution use VAE +**decode tiling** (spatial patches with overlap-blend) — see +[`TILED_VAE.md`](TILED_VAE.md), implemented in `WanVideoVAE._tiled_decode` +(auto-enabled above ~512², or `--VAE_TILE`). Note the launcher's +`VIDEO_HEIGHT`/`VIDEO_WIDTH` default to +**768**, independent of `--RESOLUTION`; pass `--VIDEO_HEIGHT 480 --VIDEO_WIDTH 480` +for t2v on a 12 GB card. + +## Memory profile + +### Understanding (`x2t_image` / `x2t_video`), Lance_3B, 5 × 3060 + +| Card | Holds | VRAM | +|---|---|---| +| cuda:0 | ~3 LLM layers + embed + lm_head + ViT + VAE + latent_pos_embed + connectors + CUDA context | ~6 GB | +| cuda:1–4 | ~8 LLM layers each | ~3 GB each | + +### Generation (`t2v`, reserve-VAE-card), Lance_3B_Video, 480² / 17 frames + +| Card | Holds | VRAM | +|---|---|---| +| cuda:0 | ~4 LLM layers + embed + lm_head + ViT + connectors | ~5 GB | +| cuda:1–3 | ~10–11 LLM layers each | ~3.7 GB each | +| cuda:4 | VAE only (decode peaks here) | ~0.8 GB idle → ~5 GB during decode | + +Smoke tests confirmed: `x2t_image`, `x2t_video`, `t2i`, and `t2v` (480², 17 +frames, 1.42 s mp4) all complete on the 8 GB-RAM / 5×3060 host. + +## Performance + +About 67 s per understanding batch at 768 resolution on the 5×3060 rig. +This is *slow* because: + +- Every layer's attention runs eager SDPA with a dense mask instead of + `flex_attention`'s compiled kernel. +- Activations shuttle across PCIe between cards via `dispatch_model`'s + hooks at each layer boundary. + +The point of this change is **fitting** the model on this hardware, not +throughput. A single A100 40 GB (cloud fallback) is the right move if you +need real speed. + +## What `main` users keep + +`shard_num_gpus=0` (the default) defers to `torch.cuda.device_count()`, +so on a 1-GPU host the device map collapses to "everything on cuda:0" +and `dispatch_model` is skipped. 
The dense-mask SDPA replacement does +run unconditionally — if you want `flex_attention` back for a single-card +setup, that's the one piece that's worth gating behind a flag. + +## File-by-file summary + +| File | Change | +|---|---| +| `inference_lance.py` | New `_build_lance_device_map` (with `reserve_last_for_vae`); `dispatch_model` import + call when sharding > 1; `shard_num_gpus` arg threading; VAE built on the last shard. | +| `inference_lance.sh` | `NUM_GPUS=5` default, `--num_processes 1`, passes `--shard_num_gpus`. | +| `config/config_factory.py` | Adds `shard_num_gpus: int = 0` to `InferenceArguments`. | +| `modeling/lance/lance.py` | New `_flex_mask_to_dense_list`; all three `create_block_mask` sites route through it. | +| `modeling/lance/qwen2_navit.py` | Layer `attention_mask.to(device=…)` handles List; `Qwen2Model.forward_train` and the `forward_inference` gen-mode branch move the sequence back to the index device after the layer loop. | +| `modeling/vae/wan/model.py` | `WanVideoVAE` accepts a `device=` override (default `get_device()`); `configure_vae_model`/`vae_encode`/`vae_decode` use it, so the VAE can live on a card other than cuda:0. | diff --git a/TILED_VAE.md b/TILED_VAE.md new file mode 100644 index 0000000..cf83cae --- /dev/null +++ b/TILED_VAE.md @@ -0,0 +1,256 @@ +# Tiled VAE decode for high-resolution video + +(The primary, implemented mechanism is spatial **tiling**; Approach B below — +distributing tiles across GPUs — is the optional "sharded" extension.) + +**Status:** Approach A (single-GPU spatial tiling) implemented **and validated** +in-container — see "Validation results" below. Approach B (multi-GPU tile +distribution) and the CPU fallbacks remain proposals. 
+ +**Implemented (Approach A):** `WanVideoVAE._tiled_decode` + `_should_tile` +(`modeling/vae/wan/model.py`) tile the latent spatially, decode each tile through +the existing `self.vae.decode` (own temporal `feat_cache` per tile), and +feather-blend into the output with weight-sum normalization. Config knobs +`vae_tile_size` / `vae_tile_overlap` (`InferenceArguments`) are plumbed through +`inference_lance.py` and `inference_lance.sh` (`--VAE_TILE` / `--VAE_TILE_OVERLAP`). +The blend arithmetic was unit-tested off-GPU: reconstructing exact-tile decodes +matches the source to ~1e-16 across divisible/non-divisible/768²-latent cases with +full coverage (`wsum ≥ 1`). + +This document plans how to lift the VAE-decode resolution ceiling that currently +caps t2v at ~480²/17 frames on a 12 GB card (see the "Resolution limit" note in +[`SHARDED_LOAD.md`](SHARDED_LOAD.md)). The goal is to decode 768² and larger on +the 8 GB-RAM / 5×3060 host. + +--- + +## 1. Problem + +t2v generation works end-to-end, but the final `WanVideoVAE` decode OOMs for +anything larger than ~480²/17 frames, even on a GPU dedicated entirely to the +VAE. At 768² the single-chunk conv activations peak just above 12 GB. + +### Why the obvious fixes don't work + +- **More GPUs for the VAE, LLM-style.** Layer-sharding the decoder across cards + (like we did for the LLM) does **not** help. The VAE weights are tiny (~0.5 GB); + the OOM is an *activation* peak — one decoder layer's full-resolution feature + map (the traceback dies in `RMS_norm`/`F.pad` at the head, at full H×W). Pinning + different layers to different cards still requires one card to hold that whole + activation. The bottleneck is spatial extent, not parameter count. + +- **Lower precision.** The decoder already runs under bf16 autocast; the final + `.float()` is a small fraction. Worth maybe ~10–20%, not the ~3× we need for + 768². 
+ +### What the decode actually does (and where the memory goes) + +`Wan2_2_VAE.decode(z, scale)` (vae2_2.py:787): + +``` +z: [1, 48, t, h, w] # latent (h = H/16, w = W/16) +x = conv2(z) +for i in range(t): # ALREADY temporally streamed, 1 latent frame at a time + out_i = decoder(x[:, :, i:i+1], feat_cache=..., first_chunk=(i==0)) + out = cat([out, out_i], dim=2) # accumulate frames +out = unpatchify(out, patch_size=2) # final 2x spatial; channels 12 -> 3 +``` + +The decoder upsamples 8× spatially (3 `Resample` stages) + the 2× `unpatchify` += 16× total. The causal temporal conv state is carried frame-to-frame in +`feat_cache` (a per-conv list, reset by `clear_cache()` at the top of `decode()`). + +**Conclusion:** temporal cost is already bounded (streamed). The remaining peak +is a *single frame's* spatial activations at full resolution. The fix is to +bound the **spatial** extent processed at once — i.e. **spatial tiling** — and, +as a second step, distribute tiles across the idle cards for speed. + +--- + +## 2. Goals & non-goals + +**Goals** +- Decode arbitrary H×W with bounded per-card memory (target: 768²–1024² on 12 GB). +- Reuse the existing, validated `decode()` per tile — avoid rewriting `Decoder3d`. +- No visible tile seams; output matches the untiled decode within tolerance. +- Off by default; opt-in via config so current 480² behavior is untouched. + +**Non-goals** +- Faster decode at sizes that already fit (tiling adds overhead — only engage it + above a threshold). +- Training / encode-side tiling (only inference decode is in scope; encode is not + on the t2v critical path). + +--- + +## 3. Approach A — spatial tiling on a single GPU (primary) + +Decode the latent in overlapping spatial tiles, each through the full (temporally +streamed) decoder, then crop the halo and feather-blend tiles into the output +canvas. 
Because `decode()` resets its own `feat_cache` per call, **each tile is a +correct independent temporal stream** — we can call the existing method per tile. + +### Mechanism + +``` +z: [1, 48, t, h, w] +canvas = zeros([1, 3, T_out, H, W]) # H=16h, W=16w; keep on CPU if large +weight = zeros([T_out, H, W]) # for feather normalization +for (lh0, lh1, lw0, lw1) in latent_tiles(h, w, tile, overlap): + # take tile + halo from the latent (real neighbor cells, zeros at borders) + z_tile = z[:, :, :, lh0-halo : lh1+halo, lw0-halo : lw1+halo] + out_tile = vae.decode(z_tile, scale) # existing method, own feat_cache + out_tile = crop_halo(out_tile, halo*16) # drop receptive-field-contaminated border + feather = ramp_mask(out_tile.shape) # linear 0->1 ramp across the overlap region + place out_tile * feather into canvas[..., region]; weight[region] += feather +canvas /= weight.clamp(min=eps) +out = canvas +``` + +### Key design points + +- **Tile in latent space.** A latent tile of `tile×tile` cells → `16·tile`² + pixels. E.g. `tile=24, halo=4` at 768² (h=w=48) gives 4 tiles of 24² latent = + 384² pixels each + halo — comfortably under the per-tile memory of the 480² + decode that already works. +- **Halo (overlap) for conv receptive field.** Decoding a tile in isolation pads + borders with zeros instead of neighbor content, so border pixels differ from + the untiled result. Take a `halo` of real neighbor latent cells on each side, + decode, then **crop `halo·16` pixels** off each interior edge so only + receptive-field-clean pixels are kept. `halo` must cover the decoder's spatial + receptive field in latent cells (see open question O1). +- **Feather blend.** Even with halo-crop, residual low-frequency mismatch can + show as seams. Overlap adjacent kept-regions by a few pixels and blend with a + linear ramp (weight accumulation as above). Halo-crop + feather together are + robust. 
+- **Output canvas.** `[1,3,17,768,768]` float ≈ 120 MB — trivial; can live on the + VAE card or CPU. Not a bottleneck. + +### Where to implement + +- New method `Wan2_2_VAE.tiled_decode(z, scale, tile, overlap, halo)` in + `modeling/vae/wan/vae2_2.py`, or a wrapper in `WanVideoVAE.vae_decode` + (`modeling/vae/wan/model.py`) that slices the latent and calls the existing + `self.vae.decode` per tile. Prefer the wrapper — zero changes to `Decoder3d`. +- `WanVideoVAE.vae_decode` decides tiled vs. plain based on a threshold / flag. + +### Cost + +Serial over tiles → ~`n_tiles`× the per-tile decode time. For 768² with 4 tiles, +~4× a 384² decode. Acceptable for correctness; Approach B parallelizes it. + +--- + +## 4. Approach B — distribute tiles across GPUs ("sharded VAE", phase 3) + +During VAE decode the LLM cards (cuda:0–3) are idle. Replicate the VAE weights +(~0.5 GB) on each participating card and decode different tiles on different +cards in parallel, then gather + blend on one card (or CPU). + +### Mechanism +- At startup (generation tasks), in addition to the dedicated VAE on cuda:N-1, + hold lightweight VAE replicas on the other cards (each has ~8 GB free during + decode since the LLM is idle but resident). +- Round-robin latent tiles across the replicas; run decodes concurrently (CUDA + is async across devices; use per-device streams or just issue and sync). +- Gather decoded tiles to the canvas device (or CPU) and feather-blend. + +### Trade-offs +- **Speedup:** up to ~`min(n_tiles, n_cards)`× over Approach A's serial loop. +- **Complexity:** weight replication, per-device latent slices, cross-device + gather, synchronization. Higher risk than A. +- **Memory:** each replica card needs `LLM-layer resident + 0.5 GB VAE + one + tile's activations`. Validate the idle-LLM cards have room (they held ~3.7 GB + of layers, leaving ~8 GB — a 384²-tile decode fits). + +Recommend B only after A is correct and if decode wall-clock matters. + +--- + +## 5. 
Approach C — fallbacks + +- **CPU-offload the accumulating output.** Move each decoded frame/tile to CPU as + produced; keeps GPU holding only the active tile. Cheap, complements A. +- **CPU decode.** Move the whole VAE to CPU. Correct but very slow (conv3d on + CPU). Last-resort for sizes that even tiling can't fit; document, don't default. + +--- + +## 6. Implementation phases + +| Phase | Scope | Deliverable | Status | +|---|---|---|---| +| 0 | Instrument: log peak VAE-decode VRAM vs. resolution; measure per-tile cost. | A table that sizes `tile`/`overlap`. | pending | +| 1 | Approach A wrapper in `WanVideoVAE.vae_decode` + `_tiled_decode`. | 768²/17-frame t2v decodes on one 12 GB card. | **done**, validated at 480² and 768² (see Validation results) | +| 2 | Config knobs + auto-enable threshold. | `--VAE_TILE` / `--VAE_TILE_OVERLAP` plumbed through launcher. | **done** | +| 3 | (optional) Approach B multi-GPU tile distribution. | Decode speedup proportional to free cards. | proposal | + +### Config / flags (phase 2) +- `vae_tile_size` (latent cells; `0` = auto, `<0` = never tile) and `vae_tile_overlap` in + `InferenceArguments` (the planned `vae_tile_halo` knob was not added). +- Auto-enable when `H*W` exceeds the measured single-card ceiling (~480²–512²); + below that, decode plainly (no tiling overhead). Because tiling stays off below that threshold, current + behavior at sizes that already fit is preserved exactly. + +--- + +## 7. Validation plan + +1. **Parity at a size that fits untiled (480²).** Decode with and without tiling; + assert max abs pixel diff below a small tolerance and PSNR high. This is the + correctness gate for halo/blend. +2. **Seam inspection.** Visually check 768² output and diff adjacent-tile borders; + no step discontinuities. +3. **Memory ceiling.** Confirm 768² (and try 1024²) decode stays under 12 GB with + `torch.cuda.max_memory_allocated()` logging. +4. **Temporal consistency.** Confirm per-tile independent `feat_cache` doesn't + introduce temporal flicker vs. untiled (the streaming is per-tile but the + latent it streams is identical, so it should match — verify). 
+5. **Frame-count / fps** unchanged (ffprobe: 17 frames @ 12 fps as today). + +### Validation results (5 × 3060, 8 GB RAM) + +- **Parity (1):** same-seed 480²/17-frame t2v, tiled (`--VAE_TILE 24 + --VAE_TILE_OVERLAP 8`) vs. plain. PSNR **39.1 dB**, mean |diff| 1.9/255. The + diff map concentrates on the moving subject's edges/texture (h264 re-encode + noise), with **no grid pattern at tile boundaries** — i.e. no seams. ✅ +- **Seams (2):** 768²/17-frame frame inspection — fur, foam, and sky are + continuous across the tile boundaries; no step discontinuities. ✅ +- **Capacity (3):** 768²/17-frame t2v auto-tiled — decode that previously OOM'd + even on a dedicated card now completes; output is a valid 768×768 h264 clip. + ✅ (1024² and `max_memory_allocated` logging not yet measured.) +- **Frame-count (5):** ffprobe reports 768×768, 17 frames @ 12 fps (1.42 s), + unchanged. ✅ +- **Temporal flicker (4):** not separately quantified beyond the per-frame + inspection; no obvious flicker. Spot-check pending. + +--- + +## 8. Risks & open questions + +- **O1 — halo width. [RESOLVED]** An overlap of 8 latent cells passed the 480² + parity test (39.1 dB, no seam grid) and produced seamless 768² output, so the + decoder's receptive field is adequately covered at `overlap=8`. Smaller + overlaps untested; 8 is the validated default. +- **O2 — temporal `feat_cache` under tiling.** Each tile re-streams all frames + with its own cache. This should match untiled (same latent, same causal + recursion per spatial location), but the `"Rep"` first-chunk handling in + `Resample.forward` (vae2_2.py:126) must be verified per tile — confirm + `first_chunk` semantics hold when the spatial extent is a sub-tile. +- **O3 — seam quality on high-frequency content.** Feather may blur fine detail + in overlap bands; tune overlap width vs. sharpness. 
+- **O4 — Approach B replica memory.** Verify idle-LLM cards truly have room for a + VAE replica + tile activations during decode (LLM weights stay resident). +- **O5 — non-square / odd sizes.** Tile loop must handle remainders (last tile + smaller) and H≠W. Use ceil-div tiling with clamped edges. + +--- + +## 9. Recommendation + +Implement **Approach A** (single-GPU spatial tiling as a `vae_decode` wrapper) +first — it's low-risk (reuses the validated `decode()` per tile), solves the +resolution ceiling outright, and is independently useful even on single-GPU +hosts. Add the config knobs (phase 2). Pursue **Approach B** only if decode +wall-clock becomes the bottleneck once correctness is proven — it's a speed +optimization, not a capability unlock. diff --git a/benchmarks/sample_env.sh b/benchmarks/sample_env.sh index 479c500..11f82de 100644 --- a/benchmarks/sample_env.sh +++ b/benchmarks/sample_env.sh @@ -38,6 +38,14 @@ lance_setup_common_env() { export CUDA_LAUNCH_BLOCKING="${CUDA_LAUNCH_BLOCKING:-0}" export NCCL_DEBUG="${NCCL_DEBUG:-VERSION}" export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="${TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC:-900}" + + # The streaming load in inference_lance.py allocates thousands of small-to-large + # tensors onto each GPU as it walks the checkpoint. The default caching allocator + # fragments under that pattern hard enough that a 1.2 GB tensor can fail to + # allocate on a card that still has plenty of total free VRAM. expandable_segments + # coalesces freed regions and lets large allocations grow into them. Required for + # the 5×3060 sharded load to succeed. 
+ export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" } diff --git a/config/config_factory.py b/config/config_factory.py index 5797364..e6efddf 100644 --- a/config/config_factory.py +++ b/config/config_factory.py @@ -307,6 +307,18 @@ class InferenceArguments(TrainingArguments): use_KVcache: bool = False enhance_prompt: bool = False # Rewrite T2V prompts before inference when enabled. + # Model-parallel sharding for low-RAM hosts: + # 0 = use all visible GPUs (torch.cuda.device_count()). + # >0 = shard Lance's LLM layers across this many GPUs via accelerate.dispatch_model. + shard_num_gpus: int = 0 + + # Spatial-tiled VAE decode for high-resolution video (see TILED_VAE.md): + # 0 = auto (tile when the latent spatial size exceeds an internal threshold) + # >0 = tile whenever max(latent_h, latent_w) exceeds this many latent cells + # <0 = never tile (force plain decode) + vae_tile_size: int = 0 + vae_tile_overlap: int = 8 # latent cells of overlap between adjacent tiles + @dataclass class EvaluationArguments(InferenceArguments): diff --git a/config/examples/t2i_single.json b/config/examples/t2i_single.json new file mode 100644 index 0000000..e22a671 --- /dev/null +++ b/config/examples/t2i_single.json @@ -0,0 +1,3 @@ +{ + "000000.png": "A beautiful girl, delicate and the half-body shot portrait, light, ultra detailed features, romantic atmosphere, gentle and ethereal mood, The warm light shines on the hair, a half-body shot, a cold and atmospheric scene, holding snowflakes, with some of the snowflakes falling on the head, and the sunlight shining on the upper left corner." 
+} \ No newline at end of file diff --git a/config/examples/t2v_single.json b/config/examples/t2v_single.json new file mode 100644 index 0000000..5bc4cd8 --- /dev/null +++ b/config/examples/t2v_single.json @@ -0,0 +1,3 @@ +{ + "000000.mp4": "A medium-close shot shows a red panda wearing a gold-trimmed cap and travel satchel on a bright seaside wave with a painted surfboard, foam spray, and a glowing summer sky. Subject fills frame; premium detail, clear focus, lively eyes, readable motion. tracking shot. It rides the wave, lifts one paw in balance, and laughs as spray catches the light." +} \ No newline at end of file diff --git a/inference_lance.py b/inference_lance.py index 27f49c4..e7e452b 100644 --- a/inference_lance.py +++ b/inference_lance.py @@ -24,13 +24,18 @@ import os.path as osp from copy import deepcopy import json -from typing import Tuple, cast, Optional +from typing import Tuple, cast, Optional, Dict, List import torch import torch.distributed as dist +from torch import nn from torch.utils.data import DataLoader from transformers import HfArgumentParser, set_seed from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig +import struct +import numpy as np from safetensors.torch import load_file +from accelerate import init_empty_weights, dispatch_model +from accelerate.utils import set_module_tensor_to_device from data.dataset_base import DataConfig, simple_custom_collate from data.data_utils import add_special_tokens @@ -114,35 +119,248 @@ }, } -def init_from_model_path_if_needed(model: Qwen2ForCausalLM, model_args: ModelArguments): - # Always load the trained Lance checkpoint from model_path. - path_dir = model_args.model_path - ema_path = osp.join(path_dir, "ema.safetensors") - model_path = osp.join(path_dir, "model.safetensors") +# Names of buffers/params that the original codepath intentionally popped from the +# checkpoint before load (they are fixed sin-cos embeddings rebuilt per resolution). 
+_POPPED_FROM_CHECKPOINT = frozenset({"latent_pos_embed.pos_embed"}) + + +def _resolve_lance_checkpoint(model_path_dir: str) -> str: + """Return the path of the Lance checkpoint to load (preferring model.safetensors).""" + for fname in ("model.safetensors", "ema.safetensors"): + cand = osp.join(model_path_dir, fname) + if osp.exists(cand): + return cand + raise FileNotFoundError( + f"No Lance checkpoint ('model.safetensors' or 'ema.safetensors') found in {model_path_dir}. " + "Download the full Lance_3B (or Lance_3B_Video) weights with:\n" + ' hf download bytedance-research/Lance --local-dir downloads --include "Lance_3B/*"' + ) + + +def _build_lance_device_map(model: "Lance", num_gpus: int, reserve_last_for_vae: bool = False) -> Dict[str, int]: + """Spread Lance's LLM transformer layers across the available cards. + + cuda:0 is the "entry/exit" device for tokens and logits (embed + lm_head + norm) + and also hosts the ViT. Those fixed-cost residents eat ~2-3 GB on cuda:0 before a + single LLM layer lands there, so we give cuda:0 a *reduced* layer share. + `reserve_last_for_vae`: when True (generation tasks), the LLM is sharded across + only the first `num_gpus - 1` cards, leaving the last GPU empty of LLM weights so + the WanVideoVAE (built on that card) has a near-full 12 GB for its decode. The + video VAE decode's conv activations (~9 GB at 480p/17 frames) won't fit on a card + that also holds LLM layers, so a dedicated card is the simplest robust fix. + """ + num_layers = len(model.language_model.model.layers) + num_gpus = max(1, num_gpus) - model_path_ft = None - if osp.exists(model_path): - model_path_ft = model_path - elif osp.exists(ema_path): - model_path_ft = ema_path + # Number of cards the LLM may use. Reserve the last one for the VAE on generation. 
+ llm_gpus = num_gpus - 1 if (reserve_last_for_vae and num_gpus >= 2) else num_gpus + llm_gpus = max(1, llm_gpus) - if model_path_ft: - model_state_dict = load_file(model_path_ft, device="cpu") + device_map: Dict[str, int] = {} + + if llm_gpus == 1: + # All LLM layers on cuda:0 (single-GPU, or 2-GPU generation with VAE on cuda:1). + for i in range(num_layers): + device_map[f"language_model.model.layers.{i}"] = 0 else: - raise FileNotFoundError( - f"Fine-tuning failed: No valid checkpoint ('ema.safetensors' or 'model.safetensors') found in {path_dir}" + # cuda:0 gets roughly half its even share; the remainder spreads across + # cuda:1..llm_gpus-1. For 36 layers / 4 LLM cards that's 4 on cuda:0, ~11 each. + gpu0_layer_count = max(1, num_layers // (2 * llm_gpus)) + remaining = num_layers - gpu0_layer_count + other_gpus = llm_gpus - 1 + layers_per_other = (remaining + other_gpus - 1) // other_gpus # ceil-div + for i in range(num_layers): + if i < gpu0_layer_count: + device_map[f"language_model.model.layers.{i}"] = 0 + else: + idx = i - gpu0_layer_count + gpu = 1 + min(idx // layers_per_other, other_gpus - 1) + device_map[f"language_model.model.layers.{i}"] = gpu + + # Token entry/exit and both MoT norms pinned to cuda:0. `norm_moe_gen` is the + # generation-branch sibling of `norm`; it must be on the same device because the + # forward path indexes a shared sequence and dispatches by token type. + device_map["language_model.model.embed_tokens"] = 0 + device_map["language_model.model.norm"] = 0 + if hasattr(model.language_model.model, "norm_moe_gen"): + device_map["language_model.model.norm_moe_gen"] = 0 + if hasattr(model.language_model.model, "rotary_emb"): + device_map["language_model.model.rotary_emb"] = 0 + device_map["language_model.lm_head"] = 0 + + # Lance heads. Small (a few MB each) except latent_pos_embed (~250 MB sin-cos); + # keep them all near the embed/connector on cuda:0. 
+ for extra in ("connector", "time_embedder", "vae2llm", "llm2vae", + "latent_pos_embed", "task_embedding", "modality_embedding"): + if hasattr(model, extra) and getattr(model, extra) is not None: + device_map[extra] = 0 + + # ViT must live on cuda:0. Lance.validation_video_to_text combines ViT outputs + # with embed_tokens outputs via `masked_scatter` inline (lance.py:1010) — that + # combine happens in parent-class Python, not inside a submodule's forward(), so + # accelerate's hooks don't get a chance to align devices. cuda:0 gets a reduced + # LLM-layer share precisely so there's headroom for the ViT + embed + lm_head. + if hasattr(model, "vit_model") and model.vit_model is not None: + device_map["vit_model"] = 0 + + # Safety net: any parameter not covered by an explicit prefix above (e.g. a future + # top-level MoT sibling we didn't anticipate) lands on cuda:0. Without this, + # accelerate.dispatch_model rejects the device_map with a hard error. + covered_prefixes = list(device_map.keys()) + for param_name, _ in model.named_parameters(): + if not any(param_name == p or param_name.startswith(p + ".") for p in covered_prefixes): + device_map[param_name] = 0 + + return device_map + + +def _device_for_param(param_name: str, device_map: Dict[str, int]) -> int: + """Find the device assignment for `param_name` by walking up its dotted path.""" + parts = param_name.split(".") + for i in range(len(parts), 0, -1): + prefix = ".".join(parts[:i]) + if prefix in device_map: + return device_map[prefix] + return 0 # default to cuda:0 for any unmapped params (Lance has very few) + + +# safetensors dtype string -> (numpy dtype used to read raw bytes, optional torch view dtype) +# bf16 has no native numpy dtype, so we read as uint16 then bit-cast via tensor.view(). 
+_SAFE_DTYPE_MAP = { + "F64": (np.float64, None), + "F32": (np.float32, None), + "F16": (np.float16, None), + "BF16": (np.uint16, torch.bfloat16), + "I64": (np.int64, None), + "I32": (np.int32, None), + "I16": (np.int16, None), + "I8": (np.int8, None), + "U8": (np.uint8, None), + "BOOL": (np.bool_, None), +} + + +def _read_safetensors_header(f) -> Tuple[Dict, int]: + """Read the 8-byte length + JSON header. Returns (header_dict, data_section_offset).""" + header_len_bytes = f.read(8) + if len(header_len_bytes) != 8: + raise ValueError(f"Truncated safetensors file: only {len(header_len_bytes)}/8 length bytes") + (header_len,) = struct.unpack(" torch.Tensor: + """Read one tensor's bytes via plain seek+read (no mmap) and return a CPU torch tensor.""" + start, end = meta["data_offsets"] + nbytes = end - start + f.seek(data_section_offset + start) + raw = f.read(nbytes) + if len(raw) != nbytes: + raise ValueError(f"Short read: got {len(raw)}/{nbytes} bytes") + np_dtype, view_dtype = _SAFE_DTYPE_MAP[meta["dtype"]] + # .copy() detaches from the read-only `raw` bytes so the buffer can be freed before + # we keep the torch tensor around. Peak CPU memory: one tensor at a time. + np_arr = np.frombuffer(raw, dtype=np_dtype).copy().reshape(meta["shape"]) + del raw + tensor = torch.from_numpy(np_arr) + if view_dtype is not None: + tensor = tensor.view(view_dtype) + return tensor + + +def _stream_load_into( + model: nn.Module, + safetensors_path: str, + device_map: Dict[str, int], + key_prefix: str = "", + skip_keys: frozenset = frozenset(), + dtype: torch.dtype = torch.bfloat16, +) -> Tuple[List[str], List[str]]: + """Stream safetensors into `model`, one tensor at a time, directly onto GPU shards. + + Uses plain `open() + seek() + read()` rather than `safetensors.safe_open()` because + safe_open mmaps the whole file (12 GB for Lance_3B) — the kernel's overcommit policy + rejects that on the 8 GB host. 
With direct IO peak CPU RAM is bounded by the largest single tensor + (~1.2 GB for the fp32 embedding layer; transiently ~2x while it is copied and cast). + + `key_prefix` is prepended to each safetensors key when looking up the target + parameter in `model` — the ViT file stores bare keys; they live under `vit_model.*` + in the Lance wrapper. + """ + loaded: List[str] = [] + unknown: List[str] = [] + model_keys = set(dict(model.named_parameters()).keys()) | set(dict(model.named_buffers()).keys()) + with open(safetensors_path, "rb") as f: + header, data_section_offset = _read_safetensors_header(f) + for key, meta in header.items(): + if key == "__metadata__": + continue + full_name = f"{key_prefix}{key}" + if full_name in skip_keys: + continue + if full_name not in model_keys: + unknown.append(full_name) + continue + tensor = _read_safetensors_tensor(f, meta, data_section_offset).to(dtype) + device = _device_for_param(full_name, device_map) + # Pass dtype= explicitly: without it, set_module_tensor_to_device casts + # `value` to `old_value.dtype` to match the meta tensor's nominal dtype + # (which is fp32 from init_empty_weights' default). That silently upcasts + # our bf16 tensors back to fp32, doubling VRAM and breaking the autocast + # path (fp32 weights * bf16 activations → fp32 output, then index-put into + # a bf16 destination crashes with a dtype-mismatch error). 
+ set_module_tensor_to_device(model, full_name, device, value=tensor, dtype=dtype) + loaded.append(full_name) + del tensor + return loaded, unknown + + +def _materialize_remaining_meta(model: "Lance", device_map: Dict[str, int], dtype: torch.dtype): + """Allocate any still-meta params on their target devices and re-init the + fixed sin-cos position embeddings (which were popped from the checkpoint).""" + from modeling.lance.modeling_utils import PositionEmbedding, PositionEmbedding3D + + materialized = [] + for name, param in list(model.named_parameters()): + if not param.is_meta: + continue + device = _device_for_param(name, device_map) + # Walk to the owning module to swap the meta param for a real one. + *mod_parts, attr = name.split(".") + owner = model + for m in mod_parts: + owner = getattr(owner, m) + new_param = torch.nn.Parameter( + torch.zeros(param.shape, dtype=dtype, device=f"cuda:{device}"), + requires_grad=param.requires_grad, ) + setattr(owner, attr, new_param) + materialized.append(name) - # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop it off, - # which makes it easier to adapt to different resolutions. - if 'latent_pos_embed.pos_embed' in model_state_dict: - model_state_dict.pop('latent_pos_embed.pos_embed') + # Same for any buffers that ended up meta (rare; defensive). + for name, buf in list(model.named_buffers()): + if not buf.is_meta: + continue + device = _device_for_param(name, device_map) + *mod_parts, attr = name.split(".") + owner = model + for m in mod_parts: + owner = getattr(owner, m) + owner.register_buffer( + attr, torch.zeros(buf.shape, dtype=buf.dtype, device=f"cuda:{device}") + ) + materialized.append(name) - msg = model.load_state_dict(model_state_dict, strict=False) # strict = True | False - clean_memory(model_state_dict) + # Re-run the sin-cos init now that the param tensors are real. 
+ for sub in model.modules(): + if isinstance(sub, (PositionEmbedding, PositionEmbedding3D)): + sub._init_weights() - return msg + return materialized def clean_memory(*objects): @@ -265,7 +483,9 @@ def validate_on_fixed_batch( save_path_gt: str = "", ): val_data = val_data_cpu.cuda(device).to_dict() - fsdp_model = fsdp_model.to(device=device, dtype=torch.bfloat16) + # Do NOT call fsdp_model.to(device) here: the model is sharded across multiple GPUs + # via accelerate.dispatch_model, and .to() would collapse all shards onto one card. + # Weights are already bf16 from the streaming load. with torch.no_grad(), torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16): # Compute padded_latent. @@ -450,11 +670,19 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): llm_config.freeze_und = training_args.freeze_und llm_config.apply_qwen_2_5_vl_pos_emb = training_args.apply_qwen_2_5_vl_pos_emb + # ===== Meta-init: build the module skeleton with zero CPU RAM. ===== + # The bare Qwen2ForCausalLM(llm_config) call used to materialize a full fp32 3B + # model on CPU (~12 GB), which is the load step that OOM-killed an 8 GB box. + # Under init_empty_weights() every nn.Parameter is created on the "meta" device + # (shape only, no storage), so this whole block stays at near-zero RAM. 
stage_start = time.perf_counter() - log_rank0(f"[startup] Initializing LLM weights: {model_args.model_path}") - language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) - log_stage("LLM weight init", stage_start) + log_rank0(f"[startup] Meta-initializing LLM: {model_args.model_path}") + with init_empty_weights(): + language_model: Qwen2ForCausalLM = Qwen2ForCausalLM(llm_config) + log_stage("LLM meta-init", stage_start) + vit_model = None + vit_config = None if training_args.visual_und: if model_args.vit_type in ("qwen2_5_vl", "qwen_2_5_vl_original"): stage_start = time.perf_counter() @@ -463,27 +691,37 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): log_stage("VIT config load", stage_start) stage_start = time.perf_counter() - log_rank0(f"[startup] Loading VIT weights: {osp.join(model_args.vit_path, 'vit.safetensors')}") - vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config) - vit_weights = load_file(osp.join(model_args.vit_path, "vit.safetensors")) - vit_model.load_state_dict(vit_weights, strict=True) - log_stage("VIT weight load", stage_start) + log_rank0("[startup] Meta-initializing VIT (weights loaded later from vit.safetensors)") + with init_empty_weights(): + vit_model = Qwen2_5_VisionTransformerPretrainedModel(vit_config) + log_stage("VIT meta-init", stage_start) else: raise ValueError(f"Unsupported vit_type: {model_args.vit_type}") - clean_memory(vit_weights) - if training_args.visual_gen: + # WanVideoVAE itself uses torch.device("meta") + assign-load internally, so it + # doesn't contribute to the CPU RAM spike. Built eagerly so vae_config is real. + # Place it on the lightest shard (the last GPU) when sharding across >1 card: + # cuda:0 is the most crowded device and the video VAE decode's conv activations + # OOM it. On a single GPU this resolves to cuda:0 (unchanged behavior). 
+ num_visible_gpus = torch.cuda.device_count() + shard_n = inference_args.shard_num_gpus or num_visible_gpus + shard_n = max(1, min(shard_n, num_visible_gpus)) + vae_device = torch.device("cuda", shard_n - 1) stage_start = time.perf_counter() - log_rank0("[startup] Initializing VAE") - vae_model = WanVideoVAE() + log_rank0(f"[startup] Initializing VAE on {vae_device} " + f"(tile_size={inference_args.vae_tile_size}, tile_overlap={inference_args.vae_tile_overlap})") + vae_model = WanVideoVAE( + device=vae_device, + tile_size=inference_args.vae_tile_size, + tile_overlap=inference_args.vae_tile_overlap, + ) vae_config: AutoEncoderParams = deepcopy(vae_model.vae_config) log_stage("VAE init", stage_start) else: vae_model = None vae_config = None - # Lance configuration config = LanceConfig( visual_gen=training_args.visual_gen, visual_und=training_args.visual_und, @@ -498,34 +736,85 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): interpolate_pos=model_args.interpolate_pos, timestep_shift=training_args.timestep_shift, ) - model: Lance = Lance( - language_model=language_model, - vit_model=vit_model if training_args.visual_und else None, - vit_type=model_args.vit_type, - config=config, - training_args=training_args, - ) + + stage_start = time.perf_counter() + log_rank0("[startup] Meta-initializing Lance wrapper") + with init_empty_weights(): + model: Lance = Lance( + language_model=language_model, + vit_model=vit_model if training_args.visual_und else None, + vit_type=model_args.vit_type, + config=config, + training_args=training_args, + ) + log_stage("Lance meta-init", stage_start) + + # ===== Decide how to shard across GPUs. ===== + num_visible_gpus = torch.cuda.device_count() + shard_n = inference_args.shard_num_gpus or num_visible_gpus + shard_n = max(1, min(shard_n, num_visible_gpus)) + # Generation tasks decode through the VAE, whose video decode needs a near-full + # card to itself; reserve the last GPU for it (the VAE was built there above). 
+ reserve_vae = bool(training_args.visual_gen) and inference_args.task in GENERATION_TASKS and shard_n >= 2 + log_rank0(f"[startup] Sharding Lance across {shard_n} GPU(s) (visible: {num_visible_gpus}; " + f"reserve last GPU for VAE: {reserve_vae})") + device_map = _build_lance_device_map(model, shard_n, reserve_last_for_vae=reserve_vae) + + # ===== Stream-load weights directly onto each shard's GPU at bf16. ===== + # ViT weights live in a separate file; in the Lance wrapper they sit under vit_model.*. + if training_args.visual_und: + vit_safetensors = osp.join(model_args.vit_path, "vit.safetensors") + stage_start = time.perf_counter() + log_rank0(f"[startup] Streaming VIT weights from {vit_safetensors}") + vit_loaded, vit_unknown = _stream_load_into( + model, vit_safetensors, device_map, key_prefix="vit_model.", dtype=torch.bfloat16, + ) + log_stage("VIT streaming load", stage_start, + extra=f"loaded={len(vit_loaded)} unknown={len(vit_unknown)}") + if vit_unknown: + log_rank0(f"[startup] WARNING: {len(vit_unknown)} ViT key(s) had no matching param " + f"(first few: {vit_unknown[:5]})") + + # The main Lance checkpoint: covers language_model.*, the connector / vae<->llm / + # time_embedder / latent_pos_embed (popped) / etc. Skip the popped sin-cos buffer. + lance_ckpt = _resolve_lance_checkpoint(model_args.model_path) stage_start = time.perf_counter() - log_rank0(f"[startup] Moving Lance model to GPU {DEVICE}") - model = model.to(DEVICE) - log_stage("Lance model move to GPU", stage_start) + log_rank0(f"[startup] Streaming Lance checkpoint from {lance_ckpt}") + main_loaded, main_unknown = _stream_load_into( + model, lance_ckpt, device_map, skip_keys=_POPPED_FROM_CHECKPOINT, dtype=torch.bfloat16, + ) + log_stage("Lance streaming load", stage_start, + extra=f"loaded={len(main_loaded)} unknown={len(main_unknown)}") + if main_unknown: + # Many Lance training-time keys (optimizer state, etc.) may not exist on the + # inference model; informational, not fatal. 
+ log_rank0(f"[startup] NOTE: {len(main_unknown)} checkpoint key(s) had no matching param " + f"(first few: {main_unknown[:5]})") + + # Anything still meta (the popped sin-cos pos_embed, any non-checkpointed buffer) + # gets allocated on its target device and re-initialized to the right values. + materialized = _materialize_remaining_meta(model, device_map, dtype=torch.bfloat16) + if materialized: + log_rank0(f"[startup] Materialized {len(materialized)} meta param/buffer(s) post-load " + f"(first few: {materialized[:5]})") + + # init_moe() copies UND weights into the moe_gen slots. For inference from a fully- + # trained Lance checkpoint, the moe_gen weights are already loaded above — running + # init_moe now would either no-op (good) or clobber them with sharded cross-device + # state_dict() copies (bad). Skip unconditionally on the meta-init path. + if training_args.copy_init_moe: + log_rank0("[startup] Skipping init_moe(): full checkpoint already contains moe_gen weights.") - # Setup tokenizer for model: + # ===== Tokenizer + post-load patch-ups. ===== stage_start = time.perf_counter() log_rank0(f"[startup] Loading tokenizer: {model_args.model_path}") tokenizer: Qwen2Tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path) - tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer) log_stage("tokenizer load and special token init", stage_start, extra=f"num_new_tokens={num_new_tokens}") - # Initialize MoE before loading the checkpoint. - if training_args.copy_init_moe: - language_model.init_moe() - - init_from_model_path_if_needed(model, model_args) - - # Resize afterward to avoid checkpoint shape mismatches or overwritten weights. if num_new_tokens > 0: + # Embedding and lm_head are both pinned to cuda:0 in the device_map, so + # resize_token_embeddings can do its in-place resize without crossing devices. 
model.language_model.resize_token_embeddings(len(tokenizer)) model.config.llm_config.vocab_size = len(tokenizer) model.language_model.config.vocab_size = len(tokenizer) @@ -534,7 +823,7 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): from common.model.hacks import hack_qwen2_5_vl_config language_model = hack_qwen2_5_vl_config(language_model) - image_token_id = language_model.config.video_token_id # image_token_id # <|image_pad|> + image_token_id = language_model.config.video_token_id # <|image_pad|> new_token_ids.update({"image_token_id": image_token_id}) model.update_tokenizer(tokenizer=tokenizer) @@ -549,7 +838,12 @@ def log_stage(stage_name: str, start_time: float, extra: str = ""): else: assert model.language_model.get_input_embeddings().weight.data.data_ptr() != model.language_model.get_output_embeddings().weight.data.data_ptr(), 'tie_word_embeddings conflict' - model = model.to(device=DEVICE, dtype=torch.bfloat16) + # ===== Attach cross-device hooks so activations flow between shards. ===== + # dispatch_model walks `device_map` and installs pre/post forward hooks that move + # activations to the right card before each submodule runs. After this point, the + # model must NOT be .to()'d as that would collapse the shards. + if shard_n > 1: + model = dispatch_model(model, device_map=device_map) model.eval() if vae_model is not None and hasattr(vae_model, "eval"): vae_model.eval() diff --git a/inference_lance.sh b/inference_lance.sh index 3a5b959..b8a83ad 100755 --- a/inference_lance.sh +++ b/inference_lance.sh @@ -5,7 +5,11 @@ cd "$SCRIPT_DIR" source "$SCRIPT_DIR/benchmarks/sample_env.sh" # ========================= Inference Parameters ========================= -NUM_GPUS=${NUM_GPUS:-1} +# NUM_GPUS is the number of GPUs to *shard* Lance across (model-parallel), not the +# number of replicas. Launch always runs a single process; the Python side uses +# accelerate.dispatch_model to split the LLM's transformer layers across NUM_GPUS +# cards. 
Default matches the 5×3060 host; override if running on fewer cards. +NUM_GPUS=${NUM_GPUS:-5} TASK_NAME=${TASK_NAME:-x2t_image} # t2i | image_edit | t2v | i2v | video_edit | x2t_image | x2t_video @@ -46,6 +50,8 @@ while [[ $# -gt 0 ]]; do --RESOLUTION) RESOLUTION="$2"; shift 2 ;; --TEXT_TEMPLATE) TEXT_TEMPLATE="$2"; shift 2 ;; --SAVE_PATH_GEN) SAVE_PATH_GEN="$2"; shift 2 ;; + --VAE_TILE) VAE_TILE="$2"; shift 2 ;; + --VAE_TILE_OVERLAP) VAE_TILE_OVERLAP="$2"; shift 2 ;; -h|--help) echo "Usage: bash inference_lance_my.sh [OPTIONS]" @@ -119,15 +125,23 @@ CONFIG_ARGS=() if [ -n "$CONFIG_PATH" ]; then CONFIG_ARGS=(--val_dataset_config_file "$CONFIG_PATH") fi +# Optional: spatial-tiled VAE decode for high-res video (see TILED_VAE.md). +if [ -n "${VAE_TILE:-}" ]; then + CONFIG_ARGS+=(--vae_tile_size "$VAE_TILE") +fi +if [ -n "${VAE_TILE_OVERLAP:-}" ]; then + CONFIG_ARGS+=(--vae_tile_overlap "$VAE_TILE_OVERLAP") +fi accelerate launch \ --num_machines $NUM_MACHINES \ - --num_processes $TOTAL_RANK \ + --num_processes 1 \ --machine_rank $MACHINE_RANK \ --main_process_ip $MAIN_PROCESS_IP \ --main_process_port $MAIN_PROCESS_PORT \ --mixed_precision bf16 \ inference_lance.py \ + --shard_num_gpus $NUM_GPUS \ --model_path "$MODEL_PATH" \ --vit_type qwen_2_5_vl_original \ --llm_qk_norm true \ diff --git a/modeling/lance/lance.py b/modeling/lance/lance.py index c425fe6..18d1746 100644 --- a/modeling/lance/lance.py +++ b/modeling/lance/lance.py @@ -40,6 +40,36 @@ from data.common import shift_position_ids from copy import deepcopy +def _flex_mask_to_dense_list( + mask_fn, + seqlen: int, + device, + dtype: torch.dtype = torch.bfloat16, +): + """Convert flex_attention's mask function (a closure over device-specific tensors) + into a List[Tensor] of dense additive masks usable by scaled_dot_product_attention. 
+ + flex_attention is incompatible with accelerate's model-parallel dispatch: the + BlockMask captures tensors on the device where it was built, and dynamo's tracer + refuses to combine them with Q/K/V tensors on a different shard. The attention + forward already has a List-of-masks branch that runs eager SDPA per sample (see + qwen2_navit.py `if isinstance(attention_mask, List)`), and SDPA crosses devices + cleanly via the standard accelerate hooks. Calling this helper at every + create_block_mask site funnels the layer attention into that SDPA branch. + """ + q_idx = torch.arange(seqlen, device=device) + kv_idx = torch.arange(seqlen, device=device) + qq, kk = torch.meshgrid(q_idx, kv_idx, indexing="ij") + # `and_masks`/`or_masks` from flex_attention call `b.new_ones(...)` on the batch + # arg, so b/h must be tensors (not ints). Sub-masks ignore b/h anyway. + b = torch.zeros((), dtype=torch.long, device=device) + h = torch.zeros((), dtype=torch.long, device=device) + bool_mask = mask_fn(b, h, qq, kk) + dense = torch.zeros((seqlen, seqlen), dtype=dtype, device=device) + dense.masked_fill_(~bool_mask, float("-inf")) + return [dense] + + class LanceConfig(PretrainedConfig): def __init__( self, @@ -140,10 +170,8 @@ def process_attention_mask(self, current_attn_modes, current_split_lens, current current_attn_modes_ = ["full" if mode_ in ["full_noise", "full_noise_target"] else mode_ for mode_ in current_attn_modes] sparse_mask = create_sparse_mask(current_seq_len, current_split_lens, current_attn_modes_, device) current_seq_len_sum = sum(current_seq_len) - attention_mask = create_block_mask( - sparse_mask, B=1, H=self.num_heads, Q_LEN=current_seq_len_sum, KV_LEN=current_seq_len_sum, device=device, BLOCK_SIZE=BLOCK_SIZE, _compile=False - ) - return attention_mask + # Dense mask List → SDPA branch in qwen2_navit.py (model-parallel safe). 
+ return _flex_mask_to_dense_list(sparse_mask, current_seq_len_sum, device) def forward( self, @@ -239,8 +267,9 @@ def forward( if nested_attention_masks is None: attn_modes_ = ["full" if mode=="full_noise" else mode for mode in attn_modes] sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes_, packed_text_embedding.device) - seqlen = sum(sample_lens) - attention_mask = create_block_mask(sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, device=packed_text_embedding.device, BLOCK_SIZE=BLOCK_SIZE, _compile=True) + seqlen = sum(sample_lens) # always equals max_num_tokens + # Dense mask List → SDPA branch (model-parallel safe). + attention_mask = _flex_mask_to_dense_list(sparse_mask, seqlen, packed_text_embedding.device) else: attention_mask = nested_attention_masks @@ -907,7 +936,8 @@ def validation_video_to_text( current_text_len = (step + 1) - (num_text_ids - 1) current_split_lens_ = current_split_lens + [current_text_len, num_pad + 1 - current_text_len] sparse_mask = create_sparse_mask(current_sample_lens, current_split_lens_, current_attn_modes_, device) - attention_mask = create_block_mask(sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen, device=device, BLOCK_SIZE=BLOCK_SIZE, _compile=False) + # Dense mask List → SDPA branch (model-parallel safe). 
+ attention_mask = _flex_mask_to_dense_list(sparse_mask, seqlen, device) extra_inputs = {"mode": "und"} if self.use_moe: diff --git a/modeling/lance/modeling_utils.py b/modeling/lance/modeling_utils.py index 4b24559..0c8e2a5 100644 --- a/modeling/lance/modeling_utils.py +++ b/modeling/lance/modeling_utils.py @@ -160,6 +160,51 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states +def _torch_1d_sincos(dim: int, pos: torch.Tensor) -> torch.Tensor: + """Torch port of get_1d_sincos_pos_embed_from_grid; runs on pos.device in fp32.""" + assert dim % 2 == 0 + device = pos.device + omega = torch.arange(dim // 2, dtype=torch.float32, device=device) + omega = 1.0 / (10000.0 ** (omega / (dim / 2.0))) # (D/2,) + out = pos.reshape(-1)[:, None] * omega[None, :] # (M, D/2) + return torch.cat([torch.sin(out), torch.cos(out)], dim=1) # (M, D) + + +def _torch_2d_sincos(embed_dim: int, grid_size: int, device, dtype) -> torch.Tensor: + grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) + grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) + # `np.meshgrid(grid_w, grid_h)` puts width first; torch's `indexing="xy"` matches that. + gw, gh = torch.meshgrid(grid_w, grid_h, indexing="xy") + emb_h = _torch_1d_sincos(embed_dim // 2, gh.flatten()) + emb_w = _torch_1d_sincos(embed_dim // 2, gw.flatten()) + return torch.cat([emb_h, emb_w], dim=1).to(dtype) + + +def _torch_3d_sincos(embed_dim: int, t: int, h: int, w: int, device, dtype) -> torch.Tensor: + """Torch port of get_3d_sincos_pos_embed; computes on `device` in fp32. + + The numpy original allocates three intermediate fp64 arrays of shape (t*h*w, ~D/3) + each plus a concatenated copy, peaking around 4 GB of CPU RAM for Lance's defaults + (t=31, h=w=64, D=2048). Doing the same work in fp32 on a GPU is ~free and avoids + the spike that OOMs the 8 GB host post-load. 
+ """ + assert embed_dim % 2 == 0 + d = embed_dim // 3 + d = d if d % 2 == 0 else d - 1 + dim_t, dim_h = d, d + dim_w = embed_dim - 2 * d + assert dim_w % 2 == 0 + + grid_t = torch.arange(t, dtype=torch.float32, device=device) + grid_h = torch.arange(h, dtype=torch.float32, device=device) + grid_w = torch.arange(w, dtype=torch.float32, device=device) + tt, hh, ww = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + emb_t = _torch_1d_sincos(dim_t, tt.flatten()) + emb_h = _torch_1d_sincos(dim_h, hh.flatten()) + emb_w = _torch_1d_sincos(dim_w, ww.flatten()) + return torch.cat([emb_t, emb_h, emb_w], dim=1).to(dtype) + + class PositionEmbedding(nn.Module): def __init__(self, max_num_patch_per_side, hidden_size): super().__init__() @@ -172,9 +217,18 @@ def __init__(self, max_num_patch_per_side, hidden_size): self._init_weights() def _init_weights(self): - # Initialize (and freeze) pos_embed by sin-cos embedding: - pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + # Skip when constructed under accelerate.init_empty_weights() — the param is on + # meta and cannot be copied into. The caller must materialize the param on a real + # device and re-invoke _init_weights() after dispatch. 
+ if self.pos_embed.is_meta: + return + with torch.no_grad(): + self.pos_embed.data.copy_( + _torch_2d_sincos( + self.hidden_size, self.max_num_patch_per_side, + device=self.pos_embed.device, dtype=self.pos_embed.dtype, + ) + ) def forward(self, position_ids): return self.pos_embed[position_ids] @@ -190,9 +244,16 @@ def __init__(self, max_latent_num_frames, max_latent_size, hidden_size): self._init_weights() def _init_weights(self): - # Initialize (and freeze) pos_embed by sin-cos embedding: - pos_embed = get_3d_sincos_pos_embed(self.hidden_size, self.max_num_latent_frames, self.max_latent_size, self.max_latent_size) - self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + # See PositionEmbedding._init_weights for the meta-tensor rationale. + if self.pos_embed.is_meta: + return + with torch.no_grad(): + self.pos_embed.data.copy_( + _torch_3d_sincos( + self.hidden_size, self.max_num_latent_frames, self.max_latent_size, self.max_latent_size, + device=self.pos_embed.device, dtype=self.pos_embed.dtype, + ) + ) def forward(self, position_ids): return self.pos_embed[position_ids] diff --git a/modeling/lance/qwen2_navit.py b/modeling/lance/qwen2_navit.py index 8f3c8d8..e7d0f3b 100644 --- a/modeling/lance/qwen2_navit.py +++ b/modeling/lance/qwen2_navit.py @@ -616,7 +616,13 @@ def forward_train( # Self Attention if attention_mask is not None: - attention_mask = attention_mask.to(device=packed_sequence_.device) + # Mask may be a BlockMask (single tensor) or a List of per-sample dense + # masks (model-parallel path that routes attention through SDPA). Move + # each element onto this layer's shard so SDPA's device check passes. 
+ if isinstance(attention_mask, list): + attention_mask = [m.to(device=packed_sequence_.device) for m in attention_mask] + else: + attention_mask = attention_mask.to(device=packed_sequence_.device) packed_sequence_ = self.self_attn( packed_sequence=packed_sequence_, @@ -892,6 +898,13 @@ def forward_train( **kwargs, ) + # Model-parallel: after the layer loop, packed_sequence lives on whichever + # shard ran the last layer (e.g. cuda:4). The index tensors and norm modules + # are pinned to cuda:0. Move the sequence back so the parent-level indexing + # below combines tensors on a single device. + if self.use_moe and packed_sequence.device != packed_und_token_indexes.device: + packed_sequence = packed_sequence.to(packed_und_token_indexes.device) + if self.use_moe: packed_sequence_ = torch.zeros_like(packed_sequence) packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes]).to(dtype=packed_sequence.dtype) @@ -956,6 +969,15 @@ def forward_inference( **kwargs, ) + # Model-parallel (inference twin of the forward_train fix): after the layer + # loop, packed_query_sequence lives on the last layer's shard (e.g. cuda:4), + # but the index tensors and norm modules are pinned to cuda:0. The gen-mode + # index-put below combines across devices, which accelerate's hooks can't + # reach (it's parent-level Python, not a submodule boundary). Move the + # sequence back to the index device first. 
+ if self.use_moe and mode == "gen" and packed_query_sequence.device != packed_text_indexes.device: + packed_query_sequence = packed_query_sequence.to(packed_text_indexes.device) + if self.use_moe: if mode == "und": packed_query_sequence = self.norm(packed_query_sequence) diff --git a/modeling/vae/wan/model.py b/modeling/vae/wan/model.py index 77f65de..54a71b3 100644 --- a/modeling/vae/wan/model.py +++ b/modeling/vae/wan/model.py @@ -32,6 +32,53 @@ def reparameterize(mu, log_var): return eps * std + mu +# --------------------------------------------------------------------------- +# Spatial-tiled VAE decode (see TILED_VAE.md). +# +# The video VAE decode's conv activations for a single frame at full resolution +# OOM a 12 GB card above ~480-512^2. The decode is already streamed temporally +# (one latent frame at a time), so the remaining peak is purely spatial. Tiling +# the latent spatially, decoding each tile through the existing (temporally +# streamed) decode, and feather-blending the outputs bounds the per-tile memory +# to a small frame, lifting the resolution ceiling. +# --------------------------------------------------------------------------- + +# Latent spatial size (cells) above which auto-tiling kicks in (vae_tile_size==0). +# 480^2 -> h=30 fits plainly; 512^2 -> 32 fits; 768^2 -> 48 OOMs. Threshold sits between. 
+_VAE_AUTO_TILE_THRESHOLD = 36 +_VAE_DEFAULT_TILE = 32 # latent cells per tile (512 px output at 16x upsample) +_VAE_DEFAULT_OVERLAP = 8 # latent cells of overlap between adjacent tiles + + +def _tile_starts(n: int, tile: int, stride: int) -> List[int]: + """Start indices of tiles covering [0, n); the last tile is snapped to the + edge so the whole extent is covered even when n is not a multiple of stride.""" + if n <= tile: + return [0] + starts = list(range(0, n - tile + 1, stride)) + if starts[-1] != n - tile: + starts.append(n - tile) + return starts + + +def _blend_ramp_1d(length: int, ramp: int, ramp_lo: bool, ramp_hi: bool, + device, dtype) -> Tensor: + """1-D blend weight: 1.0 everywhere, linearly ramped toward (but not to) 0 on + edges that overlap a neighbor. Two adjacent tiles' opposing ramps span the same + overlap band and sum to ~1; the caller's weight-sum normalization makes the + blend exact regardless, while single-coverage regions stay at weight 1.""" + w = torch.ones(length, device=device, dtype=dtype) + r = min(ramp, length // 2) + if r > 0: + # values in (0, 1): 1/(r+1) .. r/(r+1) — never exactly 0, so wsum > 0. + vals = torch.linspace(1.0 / (r + 1), r / (r + 1), r, device=device, dtype=dtype) + if ramp_lo: + w[:r] = vals + if ramp_hi: + w[length - r:] = vals.flip(0) + return w + + class WanVideoVAE(object): __version__ = "v2.2" __name__ = "WanVideoVAE" @@ -43,10 +90,22 @@ def __init__(self, config_path: str = "", **kwargs) -> None: self.logger = self.__class__.__logger__ self.dtype = kwargs.get("dtype", torch.bfloat16) + # Allow the VAE to live on a card other than cuda:LOCAL_RANK. Under + # model-parallel sharding, cuda:0 is the most crowded device (embed, lm_head, + # ViT, first LLM layers), and the video VAE decode's conv activations OOM it. + # Placing the VAE on the lightest shard gives the decode room to breathe. + # Defaults to get_device() so single-GPU behavior is unchanged. 
self.device = torch.device(kwargs.get("device", get_device())) self.configure_vae_model() self.use_sample = kwargs.get("use_sample", True) + # Spatial-tiled decode config (latent cells). See TILED_VAE.md. + # tile_size > 0 : tile whenever max(h, w) > tile_size + # tile_size == 0: auto — tile when max(h, w) > _VAE_AUTO_TILE_THRESHOLD + # tile_size < 0: never tile (force plain decode) + self.tile_size = int(kwargs.get("tile_size", 0) or 0) + self.tile_overlap = int(kwargs.get("tile_overlap", _VAE_DEFAULT_OVERLAP)) + # wan vae2.2 config is equal to seedance vae self.vae_config = AutoEncoderParams( downsample_spatial=16, @@ -97,6 +156,57 @@ def vae_encode(self, samples: List[Tensor], **kwargs) -> List[Tensor]: return latents + def _should_tile(self, u: Tensor) -> bool: + """Decide whether to spatially tile the decode of latent u [1,48,t,h,w].""" + if self.tile_size < 0: + return False + h, w = u.shape[-2], u.shape[-1] + threshold = self.tile_size if self.tile_size > 0 else _VAE_AUTO_TILE_THRESHOLD + return max(h, w) > threshold + + def _tiled_decode(self, u: Tensor) -> Tensor: + """Decode latent u [1,48,t,h,w] in overlapping spatial tiles and + feather-blend into the full output. Each tile reuses self.vae.decode, + which resets its own temporal feat_cache, so every tile is a correct + independent temporal stream. 
Returns [1,3,T,H,W].""" + _, _, _, h, w = u.shape + tile = self.tile_size if self.tile_size > 0 else _VAE_DEFAULT_TILE + # overlap must leave a positive stride and fit within a tile + overlap = max(0, min(self.tile_overlap, tile // 2 - 1)) + stride = max(1, tile - overlap) + + row_starts = _tile_starts(h, tile, stride) + col_starts = _tile_starts(w, tile, stride) + + canvas = None + wsum = None + f = None # spatial upsample factor (pixels per latent cell), inferred from first tile + for r0 in row_starts: + r1 = min(r0 + tile, h) + for c0 in col_starts: + c1 = min(c0 + tile, w) + out = self.vae.decode(u[:, :, :, r0:r1, c0:c1]) # [1,3,T,(r1-r0)*f,(c1-c0)*f] + + if canvas is None: + f = out.shape[-2] // (r1 - r0) + T_out, C_out = out.shape[2], out.shape[1] + H, W = h * f, w * f + canvas = torch.zeros((1, C_out, T_out, H, W), dtype=out.dtype, device=out.device) + wsum = torch.zeros((1, 1, 1, H, W), dtype=out.dtype, device=out.device) + + py0, py1, px0, px1 = r0 * f, r1 * f, c0 * f, c1 * f + wy = _blend_ramp_1d(py1 - py0, overlap * f, ramp_lo=(r0 != 0), ramp_hi=(r1 != h), + device=out.device, dtype=out.dtype) + wx = _blend_ramp_1d(px1 - px0, overlap * f, ramp_lo=(c0 != 0), ramp_hi=(c1 != w), + device=out.device, dtype=out.dtype) + w2d = (wy[:, None] * wx[None, :])[None, None, None, :, :] # [1,1,1,ph,pw] + + canvas[:, :, :, py0:py1, px0:px1] += out * w2d + wsum[:, :, :, py0:py1, px0:px1] += w2d + del out + + return canvas / wsum.clamp(min=1e-6) + @torch.no_grad() def vae_decode(self, latents: List[Tensor], **kwargs) -> List[Tensor]: device = self.device @@ -107,7 +217,10 @@ def vae_decode(self, latents: List[Tensor], **kwargs) -> List[Tensor]: u = u.unsqueeze(0).to(device=device) # -> [1,t,h,w,48] u = rearrange(u, "b ... 
c -> b c ...") # -> [1,48,t,h,w] - x_hat = self.vae.decode(u) # -> [1,3,T,H,W] + if self._should_tile(u): + x_hat = self._tiled_decode(u) # -> [1,3,T,H,W] + else: + x_hat = self.vae.decode(u) # -> [1,3,T,H,W] samples.append(x_hat.squeeze(0)) # -> List[[3,T,H,W]]