Run Lance inference on low-RAM / multi-GPU consumer hardware (e.g. 8 GB RAM + 5×3060) #43

Open
johbau wants to merge 5 commits into bytedance:main from johbau:low-ram-sharded-load

johbau commented Jun 19, 2026

Enables Lance inference on hosts that can't fit the model the stock way — too little system RAM to materialize the 3B model on CPU, and no single GPU large enough for it. Validated
end-to-end on an 8 GB-RAM / 5×RTX 3060 (12 GB each) box across all four task families: x2t_image, x2t_video, t2i, and t2v (incl. 768² via tiling).

Five focused commits:

  • Low-RAM load — build the model under init_empty_weights() and stream the safetensors checkpoint directly onto the GPU in bf16 (no full fp32 copy on CPU, no mmap), so weight load
    no longer OOMs on low system RAM.
  • Model-parallel sharding: a device_map plus accelerate.dispatch_model spreads the LLM layers across N GPUs (--shard_num_gpus); flex-attention is swapped for an eager-SDPA dense-mask
    path that's safe across shards.
  • Generation-path sharding fixes — device-alignment for the forward_inference (KV-cache) gen path and a dedicated VAE card for video decode.
  • Smoke-test scaffolding — single-prompt example configs.
  • Tiled VAE decode — overlapping spatial tiles with feather-blending lift the VAE-decode resolution ceiling (768²+ on a 12 GB card); auto-enables above ~512², off by default below.

No behavior change for hosts with enough RAM/VRAM: --shard_num_gpus defaults to all visible GPUs (collapses to single-GPU when there's one), and tiling stays off at low
resolutions. Design notes in LOW_RAM_LOAD.md, SHARDED_LOAD.md, TILED_VAE.md.

johbau and others added 5 commits June 19, 2026 17:58
Lets inference_lance.py run on hosts with less system RAM than the model
needs to materialize in fp32 on CPU (~12 GB for Lance_3B). On main, the
first call `Qwen2ForCausalLM(llm_config)` allocates a freshly-init'd fp32
3B model on CPU and OOM-kills an 8 GB host before any GPU code runs.
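
The ~12 GB figure follows from simple arithmetic on the weights alone. A minimal sketch, assuming a nominal 3 billion parameters (the exact count is not stated here) and decimal gigabytes:

```python
# Nominal "3B" parameter count; the exact number is an assumption.
PARAMS = 3_000_000_000
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2}


def weight_gb(dtype, params=PARAMS):
    """Size of the weights alone, in decimal gigabytes."""
    return params * BYTES_PER_PARAM[dtype] / 1e9


print(weight_gb("fp32"))  # 12.0, more than an 8 GB host can hold
print(weight_gb("bf16"))  # 6.0
```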

Changes:

- Build LLM / ViT / Lance wrapper under `accelerate.init_empty_weights()`
  so every nn.Parameter is shape-only on the meta device — near-zero CPU
  RAM during construction.

- Replace `safetensors.safe_open()` with a hand-rolled reader that does
  plain seek+read of one tensor at a time. safe_open mmaps the whole 12 GB
  file, which Linux refuses on a host with strict overcommit / no swap
  (ENOMEM). Peak CPU RAM during load is now bounded by the largest single tensor.

- Pass `dtype=torch.bfloat16` to `set_module_tensor_to_device` so loaded
  values aren't silently upcast back to the meta tensor's fp32 default.
  Without this the model lives at fp32 on the GPU, doubling VRAM and
  breaking the bf16 autocast path (fp32 weights * bf16 activations
  → fp32 output, then index-put into bf16 destination crashes).

- Replace numpy fp64 sin-cos position embeddings with a torch fp32 port
  that computes on the param's device. PositionEmbedding3D._init_weights
  used to peak around 4 GB of CPU RAM building 3 intermediate arrays of
  shape (t*h*w, ~D/3); the GPU version contributes ~zero CPU.

- Materialize any params left on `meta` after the load (the popped
  latent_pos_embed sin-cos buffer) on the target device and re-init.

- Set `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` so the per-tensor
  streaming pattern doesn't fragment the CUDA caching allocator into a
  state where large allocations fail despite plenty of free VRAM.
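
The seek+read approach from the list above can be sketched with the standard library alone. This is an illustration of the idea, not the reader in this PR: the function names are hypothetical, and it relies only on the published safetensors layout (an 8-byte little-endian header length, a JSON header, then the raw tensor bytes).

```python
import json
import struct


def read_header(path):
    """Return (header_dict, data_start) for a safetensors file."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header, 8 + header_len


def iter_tensor_bytes(path):
    """Yield (name, dtype, shape, raw_bytes) one tensor at a time.

    Uses plain seek+read, so nothing is mmapped and only the current
    tensor's bytes are held by this function.
    """
    header, data_start = read_header(path)
    # Visit tensors in file order so reads stay sequential.
    entries = sorted(header.items(), key=lambda kv: kv[1]["data_offsets"][0])
    with open(path, "rb") as f:
        for name, info in entries:
            begin, end = info["data_offsets"]  # relative to data_start
            f.seek(data_start + begin)
            yield name, info["dtype"], info["shape"], f.read(end - begin)
```

A caller would turn each `raw_bytes` buffer into a tensor of the declared dtype and shape, move it to its target device, and drop the buffer before reading the next one.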

Peak CPU RSS during load stays under ~2 GB. The model loads in bf16
onto whichever device the runner picks (cuda:LOCAL_RANK). Multi-GPU
sharding for hosts that can't fit the model on one card is a separate
follow-up — see SHARDED_LOAD.md.

See LOW_RAM_LOAD.md for the full memory profile and per-file rationale.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Builds on the low-RAM streaming load (previous commit) to enable single-
process, model-parallel inference across N GPUs. Lets Lance run on hosts
where no single card has enough VRAM but the aggregate does — e.g.
5 × RTX 3060 (60 GB) for Lance_3B + ViT + VAE.

Changes:

- `_build_lance_device_map(model, num_gpus)` spreads LLM transformer
  layers across cuda:0..N-1 with cuda:0 getting a reduced share (it also
  hosts embed/lm_head/ViT/VAE/connectors). A safety net pins any
  uncovered parameter to cuda:0 so future top-level MoT siblings don't
  break dispatch.

- `accelerate.dispatch_model` installs pre/post forward hooks that move
  activations between cards as needed. The streaming loader from the
  previous commit already routes each tensor onto its target shard at
  load time; this commit just attaches the runtime hooks.

- Replace flex_attention with eager-SDPA in all three call sites in
  lance.py. flex_attention's BlockMask captures device-specific tensors
  that dynamo refuses to combine with Q/K/V from a different shard.
  qwen2_navit.py already has an isinstance(attention_mask, List) → SDPA
  branch that crosses devices cleanly via accelerate hooks. A new helper
  `_flex_mask_to_dense_list` evaluates the flex mask function on a
  (q_idx, kv_idx) meshgrid to produce that List.

- Align parent-class Python combine sites that dispatch_model's hooks
  can't reach:
  - Pin ViT to cuda:0 so its output matches embed_tokens for the inline
    `masked_scatter` in validation_video_to_text.
  - Move `packed_sequence` back to the index tensor's device after the
    Qwen2Model layer loop, so the final norm/lm_head indexing combine
    works.
  - Handle both Tensor and List in the per-layer
    `attention_mask.to(device=…)` call.

- Launcher and config:
  - inference_lance.sh: NUM_GPUS=5 default (now means shard count, not
    data-parallel rank count), `--num_processes 1`, forwards
    `--shard_num_gpus $NUM_GPUS`.
  - InferenceArguments: new `shard_num_gpus: int = 0`. 0 = use all
    visible GPUs; >0 caps to that many.

Behavior on a 1-GPU host is unchanged — the device map collapses to
"everything on cuda:0" and dispatch_model is skipped. The dense-mask
SDPA replacement runs unconditionally; flex_attention can be gated
behind a flag if a single-card user wants the compiled kernel back.
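
The dense-mask idea can be shown without any framework: evaluate the mask predicate on every (q_idx, kv_idx) pair and keep the resulting boolean grid. The sketch below is pure Python with hypothetical names; the two-argument predicate and the packed-causal rule are assumptions for illustration, not the signature or mask used by `_flex_mask_to_dense_list`.

```python
def dense_mask(mask_fn, q_len, kv_len):
    """Evaluate mask_fn(q_idx, kv_idx) on the full index grid."""
    return [[bool(mask_fn(q, k)) for k in range(kv_len)] for q in range(q_len)]


def packed_causal(sample_ids):
    """Causal attention restricted to tokens of the same packed sample."""
    def mask_fn(q_idx, kv_idx):
        return sample_ids[q_idx] == sample_ids[kv_idx] and kv_idx <= q_idx
    return mask_fn


# Two samples of lengths 2 and 3 packed into one sequence of 5 tokens.
mask = dense_mask(packed_causal([0, 0, 1, 1, 1]), 5, 5)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

The cost of this approach is that the dense mask takes memory proportional to q_len × kv_len, which a block-sparse mask avoids.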

Memory profile on a 5 × 3060 / 8 GB RAM host:
  cuda:0  ~6 GB  (3 LLM layers + ViT + VAE + embed + lm_head + extras)
  cuda:1  ~3 GB  (8 LLM layers)
  cuda:2  ~3 GB  (8 LLM layers)
  cuda:3  ~3 GB  (8 LLM layers)
  cuda:4  ~3 GB  (8 LLM layers)
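
A split like the one listed above can be produced by a small helper. This is a sketch under stated assumptions, not the body of `_build_lance_device_map`: the function names, the `llm.layers` prefix, and the explicit `first_share` argument are all hypothetical, and the layer total is simply the sum of the counts in the profile.

```python
def resolve_shard_count(requested, visible):
    """0 means use every visible GPU; a positive value caps the count."""
    return visible if requested <= 0 else min(requested, visible)


def split_layers(num_layers, num_gpus, first_share):
    """Per-GPU layer counts, with a reduced share on GPU 0."""
    if num_gpus == 1:
        return [num_layers]
    rest = num_layers - first_share
    base, extra = divmod(rest, num_gpus - 1)
    # Spread any remainder over the earliest of the remaining GPUs.
    return [first_share] + [base + (1 if i < extra else 0)
                            for i in range(num_gpus - 1)]


def build_device_map(num_layers, num_gpus, first_share, prefix="llm.layers"):
    """Map each layer name to a GPU index, filling GPUs in order."""
    device_map, layer = {}, 0
    for gpu, count in enumerate(split_layers(num_layers, num_gpus, first_share)):
        for _ in range(count):
            device_map[f"{prefix}.{layer}"] = gpu
            layer += 1
    return device_map


num_gpus = resolve_shard_count(requested=0, visible=5)
print(split_layers(35, num_gpus, first_share=3))  # [3, 8, 8, 8, 8]
```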

Smoke test (x2t_image, 768 res, 6 cases) completes successfully. About
67 s per understanding batch — slow because activations shuttle across
PCIe between cards and SDPA is eager. The point is fitting the model on
this hardware, not throughput.

See SHARDED_LOAD.md for the full rationale and per-file summary.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The diffusion + VAE-decode generation path uses forward_inference (KV-cache)
and the WanVideoVAE, neither exercised by the understanding path. Three
model-parallel fixes were needed to run t2i/t2v across multiple GPUs:

- forward_inference gen-mode norm (qwen2_navit.py): inference-mode twin of
  the forward_train post-loop fix. After the decoder-layer loop the sequence
  is on the last shard, but the index tensors + norm/norm_moe_gen are on
  cuda:0. Added a .to(packed_text_indexes.device) guard before the gen-mode
  index-put. Runs every diffusion timestep. (The KV-cache path attends via
  flash_attn_varlen_func, not flex_attention, so the dense-mask change is not
  exercised here.)

- VAE device override (modeling/vae/wan/model.py): WanVideoVAE hard-coded
  get_device()=cuda:0. It now accepts a device= argument (default get_device(), so
  single-GPU is unchanged) used by configure_vae_model/vae_encode/vae_decode,
  so the VAE can live on a less-crowded card.

- Dedicated VAE card (inference_lance.py): the video VAE decode's conv
  activations (~9 GB at 480^2 / 17 frames) won't fit on a card that also
  holds LLM layers. For generation tasks, _build_lance_device_map now takes
  reserve_last_for_vae and shards the LLM across the first N-1 cards, leaving
  the last card empty of LLM weights; the VAE is built there.

Confirmed on the 8 GB-RAM / 5x3060 host: t2i produces a clean PNG and t2v
produces a valid 480x480 17-frame h264 mp4 (1.42 s). Note: 768^2 video decode
exceeds a single 12 GB card even when dedicated; that needs VAE decode tiling
(not implemented in this commit). The launcher's VIDEO_HEIGHT/WIDTH default to 768, so pass
--VIDEO_HEIGHT 480 --VIDEO_WIDTH 480 for t2v on 12 GB cards. See SHARDED_LOAD.md.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Convenience for bounded generation smoke tests (one output instead of all
prompts in the example JSON, which apply_inference_defaults expands to
validation_max_samples=100000):

- inference_lance.sh: new --DATASET_CONFIG passthrough that forwards
  --val_dataset_config_file to the Python side.
- config/examples/t2i_single.json, t2v_single.json: first prompt of the
  corresponding example file, so a smoke test generates a single image/video.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Lifts the VAE-decode resolution ceiling (previously ~480-512^2 on a 12 GB
card; 768^2 OOMs even on a card dedicated to the VAE). The decode is already
streamed temporally, so the memory peak is a single frame's full-resolution
conv activations — a spatial problem, not a weight one (so LLM-style layer
sharding wouldn't help). Tiling the latent spatially and feather-blending the
per-tile decodes bounds per-tile memory and lifts the ceiling.

- WanVideoVAE._tiled_decode / _should_tile (modeling/vae/wan/model.py):
  slice the latent [1,48,t,h,w] into overlapping spatial tiles, decode each
  via the existing self.vae.decode (which resets its own temporal feat_cache,
  so each tile is a correct independent temporal stream), and blend into the
  output with a linear edge ramp + weight-sum normalization. Reuses the
  validated decode per tile — no Decoder3d rewrite.
- Config: vae_tile_size (0=auto above ~512^2 latent, >0=force, <0=disable) and
  vae_tile_overlap in InferenceArguments; plumbed through inference_lance.py
  and inference_lance.sh (--VAE_TILE / --VAE_TILE_OVERLAP).
- Blend arithmetic unit-tested off-GPU: reconstructing exact-tile decodes
  matches the source to ~1e-16 across divisible/non-divisible/768^2-latent
  cases with full coverage (wsum >= 1).
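
The blend arithmetic can be illustrated in one dimension with plain Python. This is a simplified sketch, not `_tiled_decode`: all names are hypothetical, the "decode" is any per-tile function that preserves length (the real decoder upsamples), every tile edge is feathered including the outer ones, and the area-based auto threshold in `should_tile` is an assumption about what "above ~512^2" means.

```python
def should_tile(vae_tile_size, height, width, auto_above=512):
    """Flag semantics as documented: <0 off, >0 forced, 0 auto by size."""
    if vae_tile_size < 0:
        return False
    if vae_tile_size > 0:
        return True
    return height * width > auto_above * auto_above


def tile_starts(length, tile, overlap):
    """Start offsets of overlapping tiles that cover [0, length)."""
    if tile >= length:
        return [0]
    stride = tile - overlap
    assert stride > 0, "overlap must be smaller than the tile"
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # last tile sits flush with the end
    return starts


def ramp_weights(tile, overlap):
    """Linear feather: weights rise over `overlap` samples at each edge."""
    w = [1.0] * tile
    for i in range(min(overlap, tile // 2)):
        w[i] = w[tile - 1 - i] = (i + 1) / (overlap + 1)
    return w


def tiled_apply(signal, decode, tile, overlap):
    """Decode each tile independently, then blend with weight-sum normalization."""
    n = len(signal)
    out, wsum = [0.0] * n, [0.0] * n
    for s in tile_starts(n, tile, overlap):
        chunk = decode(signal[s:s + tile])
        for i, (v, wi) in enumerate(zip(chunk, ramp_weights(len(chunk), overlap))):
            out[s + i] += wi * v
            wsum[s + i] += wi
    return [o / ws for o, ws in zip(out, wsum)]
```

With an identity `decode`, the blended result reproduces the input up to floating-point rounding, which is the property the off-GPU unit test described above checks for the real implementation.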

TILED_VAE.md documents the design (Approach A: single-GPU tiling, implemented;
Approach B: multi-GPU tile distribution, proposal) and the pending in-container
validation (480^2 parity, 768^2 memory, seams). SHARDED_LOAD.md cross-links it.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>