diff --git a/cosmos_framework/callbacks/dit_image_sample.py b/cosmos_framework/callbacks/dit_image_sample.py new file mode 100644 index 0000000..4615e37 --- /dev/null +++ b/cosmos_framework/callbacks/dit_image_sample.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Fixed class-conditioned image sampling callback for DiT training.""" + +from __future__ import annotations + +from contextlib import nullcontext +from typing import Any + +import torch +import torchvision +import wandb + +from cosmos_framework.callbacks.every_n import EveryN +from cosmos_framework.model._base import ImaginaireModel +from cosmos_framework.trainer import ImaginaireTrainer +from cosmos_framework.utils import distributed + + +class DiTImageSampleCallback(EveryN): + """Generate fixed ImageNet class samples through ``model.generate_image``.""" + + def __init__( + self, + every_n: int = 5000, + class_ids: list[int] | None = None, + cfg_scales: list[float] | None = None, + num_steps: int = 50, + seed: int = 0, + is_ema: bool = True, + run_at_start: bool = False, + ) -> None: + super().__init__(every_n=every_n, run_at_start=run_at_start) + self.class_ids = class_ids or [0, 1, 2, 3] + self.cfg_scales = cfg_scales or [1.0, 1.25, 1.5, 2.0] + self.num_steps = num_steps + self.seed = seed + self.is_ema = is_ema + self.rank = distributed.get_rank() + + @torch.no_grad() + def every_n_impl( + self, + trainer: ImaginaireTrainer, + model: ImaginaireModel, + data_batch: dict[str, torch.Tensor], + output_batch: dict[str, torch.Tensor], + loss: torch.Tensor, + iteration: int, + ) -> None: + del trainer, data_batch, output_batch, loss + + if not hasattr(model, "generate_image"): + raise AttributeError("DiTImageSampleCallback requires model.generate_image().") + if self.is_ema and not model.config.ema.enabled: + return + + was_training = model.training + context: Any = 
model.ema_scope("dit_image_sample") if self.is_ema else nullcontext() + generated_rows: list[torch.Tensor] = [] + seed_list = [self.seed + sample_idx for sample_idx in range(len(self.class_ids))] + try: + with context: + for cfg_scale in self.cfg_scales: + images = model.generate_image( + class_ids=self.class_ids, + num_steps=self.num_steps, + cfg_scale=cfg_scale, + seed=seed_list, + ) # [B,3,H,W] + if self.rank == 0: + generated_rows.append(images.detach().float().cpu()) # [B,3,H,W] + finally: + if was_training: + model.train() + + if self.rank != 0 or wandb.run is None or not generated_rows: + return + + grid_images = torch.cat(generated_rows, dim=0) # [R*B,3,H,W] + grid = torchvision.utils.make_grid(grid_images, nrow=len(self.class_ids), padding=2, normalize=False) # [3,H,W] + grid_np = grid.clamp(0.0, 1.0).permute(1, 2, 0).numpy() # [H,W,3] + tag = "ema" if self.is_ema else "reg" + caption = f"classes={self.class_ids}, cfg={self.cfg_scales}, steps={self.num_steps}, seed={self.seed}, {tag}" + wandb.log( + {f"dit_image_sample/{tag}": wandb.Image(grid_np, caption=caption)}, + step=iteration, + ) diff --git a/cosmos_framework/callbacks/iter_speed.py b/cosmos_framework/callbacks/iter_speed.py index 738721e..7870003 100644 --- a/cosmos_framework/callbacks/iter_speed.py +++ b/cosmos_framework/callbacks/iter_speed.py @@ -78,7 +78,8 @@ def every_n_impl( ) per_sample_batch_counter = dict() - if hasattr(model, "is_image_batch"): + # for VFM + if hasattr(model, "is_image_batch") and hasattr(model, "input_image_key") and hasattr(model, "input_video_key"): is_image_batch = model.is_image_batch(data_batch) if is_image_batch: image_batch_size = len(data_batch[model.input_image_key]) @@ -86,6 +87,18 @@ def every_n_impl( else: video_batch_size = len(data_batch[model.input_video_key]) per_sample_batch_counter["video_batch_size"] = video_batch_size + # for LLM training only + elif "input_ids" in data_batch: + mbs = data_batch["input_ids"].shape[0] + dp_size = 
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + grad_accum_iter = int(trainer.config.trainer.grad_accum_iter) + per_sample_batch_counter["token_batch_size"] = mbs + per_sample_batch_counter["token_global_batch_size"] = mbs * dp_size * grad_accum_iter + # Cumulative token count (LLM analog of sample_counter). Set by + # ``LLMPretrainModel.training_step`` into a persistent buffer on + # ``model.net``, so this value survives checkpoint resume. + if hasattr(model, "token_counter"): + per_sample_batch_counter["token_counter"] = model.token_counter if wandb.run: sample_counter = getattr(trainer, "sample_counter", iteration) diff --git a/cosmos_framework/callbacks/norm_monitor.py b/cosmos_framework/callbacks/norm_monitor.py index 46f51cb..3c68303 100644 --- a/cosmos_framework/callbacks/norm_monitor.py +++ b/cosmos_framework/callbacks/norm_monitor.py @@ -14,7 +14,7 @@ from cosmos_framework.utils import distributed, log, misc from cosmos_framework.utils.callback import Callback from cosmos_framework.utils.easy_io import easy_io -from cosmos_framework.data.vfm.sequence_packing import get_gen_seq +from cosmos_framework.data.vfm.sequence_packing.runtime import get_gen_seq try: from apex.contrib.layer_norm import FastLayerNorm diff --git a/cosmos_framework/callbacks/sequence_packing_padding.py b/cosmos_framework/callbacks/sequence_packing_padding.py index c4ab4b8..91f3b7e 100644 --- a/cosmos_framework/callbacks/sequence_packing_padding.py +++ b/cosmos_framework/callbacks/sequence_packing_padding.py @@ -4,10 +4,10 @@ import torch import wandb -import cosmos_framework.data.vfm.sequence_packing as sequence_packing from cosmos_framework.callbacks.every_n import EveryN from cosmos_framework.model._base import ImaginaireModel from cosmos_framework.trainer import ImaginaireTrainer +from cosmos_framework.data.vfm.sequence_packing.runtime import get_padding_stats class SequencePackingPadding(EveryN): @@ -32,11 +32,12 @@ def every_n_impl( iteration: int, 
) -> None: if wandb.run: + padding_stats = get_padding_stats() log_dict = { - "SequencePackingPadding/max_causal_len_image_batch": sequence_packing.MAX_CAUSAL_LEN_IMAGE_BATCH, - "SequencePackingPadding/max_full_len_image_batch": sequence_packing.MAX_FULL_LEN_IMAGE_BATCH, - "SequencePackingPadding/max_causal_len_video_batch": sequence_packing.MAX_CAUSAL_LEN_VIDEO_BATCH, - "SequencePackingPadding/max_full_len_video_batch": sequence_packing.MAX_FULL_LEN_VIDEO_BATCH, + "SequencePackingPadding/max_causal_len_image_batch": padding_stats["MAX_CAUSAL_LEN_IMAGE_BATCH"], + "SequencePackingPadding/max_full_len_image_batch": padding_stats["MAX_FULL_LEN_IMAGE_BATCH"], + "SequencePackingPadding/max_causal_len_video_batch": padding_stats["MAX_CAUSAL_LEN_VIDEO_BATCH"], + "SequencePackingPadding/max_full_len_video_batch": padding_stats["MAX_FULL_LEN_VIDEO_BATCH"], } modality = "video" if "is_image_batch" in output_batch: diff --git a/cosmos_framework/configs/base/defaults/callbacks.py b/cosmos_framework/configs/base/defaults/callbacks.py index 805ffb5..dfe0331 100644 --- a/cosmos_framework/configs/base/defaults/callbacks.py +++ b/cosmos_framework/configs/base/defaults/callbacks.py @@ -10,7 +10,9 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.compile_tokenizer import CompileTokenizer + from cosmos_framework.callbacks.device_monitor import DeviceMonitor +from cosmos_framework.callbacks.dit_image_sample import DiTImageSampleCallback from cosmos_framework.callbacks.every_n_draw_sample import EveryNDrawSample from cosmos_framework.callbacks.expert_heatmap import ExpertHeatmap from cosmos_framework.callbacks.grad_clip import GradClip @@ -72,6 +74,41 @@ ofu=L(OFUCallback)(every_n="${trainer.logging_iter}"), ) +# LLM-only subset of BASIC_CALLBACKS. 
+# Drops VFM-specific callbacks: +# - CompileTokenizer: requires model.tokenizer_vision_gen (VAE) +# - ExpertHeatmap: requires MoE language_model with mlp_moe_gen +# - SigmaLossAnalysis: rectified-flow specific +# - SequencePackingPadding: VFM multi-modal packing specific +# - NormMonitor: param filter assumes "moe_gen" params → logs nothing for dense LLM +# Also drops callbacks that are needed but not yet supported: +# - MFU: TODO +BASIC_LLM_CALLBACKS = dict( + iter_speed=L(IterSpeed)( + every_n="${trainer.logging_iter}", + save_s3="${upload_reproducible_setup}", + save_s3_every_log_n=500, + hit_thres=50, + ), + manual_gc=L(ManualGarbageCollection)(every_n=5), + wandb=L(WandBCallback)(), + wandb_2x=L(WandBCallbackMultiplier)( + logging_iter_multipler=2, + save_logging_iter_multipler=1, + save_s3="${upload_reproducible_setup}", + ), + param_count=L(ParamCount)( + save_s3="${upload_reproducible_setup}", + ), + wandb_val=L(WandBCallbackEval)( + save_s3="${upload_reproducible_setup}", + ), + ofu=L(OFUCallback)(every_n="${trainer.logging_iter}"), +) + +# DiT-safe subset for LLM-backed rectified-flow image training. +BASIC_DIT_CALLBACKS = dict(BASIC_LLM_CALLBACKS) + JOB_MONITOR_CALLBACKS = dict( heart_beat=L(HeartBeat)( every_n=200, @@ -94,6 +131,19 @@ low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), # use model ) +OPTIMIZATION_LLM_CALLBACKS = dict( + skip_nan_step=L(SkipNaNStep)(max_consecutive_nan=100), + grad_clip=L(GradClip)(clip_norm=1.0, track_per_modality=False), + low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), +) + +# DiT reuses the same GradClip callback as LLM, without VFM image/video grad-norm split.
+OPTIMIZATION_DIT_CALLBACKS = dict( + skip_nan_step=L(SkipNaNStep)(max_consecutive_nan=100), + grad_clip=L(GradClip)(clip_norm=1.0, track_per_modality=False), + low_precision=L(LowPrecisionCallback)(update_iter=1, config=PLACEHOLDER, trainer=PLACEHOLDER), +) + VIZ_ONLINE_SAMPLING_CALLBACKS = dict( every_n_sample_reg=L(EveryNDrawSample)( every_n=5000, @@ -108,18 +158,34 @@ ), ) +DIT_IMAGE_SAMPLING_CALLBACKS = dict( + dit_image_sample_ema=L(DiTImageSampleCallback)( + every_n=5000, + class_ids=[0, 1, 2, 3], + cfg_scales=[1.0, 1.25, 1.5, 2.0], + num_steps=50, + seed=0, + is_ema=True, + ), +) + def register_callbacks(): cs = ConfigStore.instance() cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS) cs.store(group="callbacks", package="trainer.callbacks", name="job_monitor", node=JOB_MONITOR_CALLBACKS) cs.store(group="callbacks", package="trainer.callbacks", name="optimization", node=OPTIMIZATION_CALLBACKS) + cs.store(group="callbacks", package="trainer.callbacks", name="optimization_llm", node=OPTIMIZATION_LLM_CALLBACKS) + cs.store(group="callbacks", package="trainer.callbacks", name="optimization_dit", node=OPTIMIZATION_DIT_CALLBACKS) # Online sampling generation callback cs.store( group="callbacks", package="trainer.callbacks", name="viz_online_sampling", node=VIZ_ONLINE_SAMPLING_CALLBACKS ) # Register "generation" as alias for "viz_online_sampling" (expected by base config.py defaults) cs.store(group="callbacks", package="trainer.callbacks", name="generation", node=VIZ_ONLINE_SAMPLING_CALLBACKS) + cs.store( + group="callbacks", package="trainer.callbacks", name="dit_image_sampling", node=DIT_IMAGE_SAMPLING_CALLBACKS + ) TRAINING_STATS_CALLBACKS = dict( training_stats=L(TrainingStatsCallback)( @@ -127,3 +193,7 @@ def register_callbacks(): ) ) cs.store(group="callbacks", package="trainer.callbacks", name="training_stats", node=TRAINING_STATS_CALLBACKS) + + # Reduced callback sets that drop everything that does not work for LLM
training + cs.store(group="callbacks", package="trainer.callbacks", name="basic_llm", node=BASIC_LLM_CALLBACKS) + cs.store(group="callbacks", package="trainer.callbacks", name="basic_dit", node=BASIC_DIT_CALLBACKS) diff --git a/cosmos_framework/configs/base/defaults/llm_model.py b/cosmos_framework/configs/base/defaults/llm_model.py new file mode 100644 index 0000000..a4e55a9 --- /dev/null +++ b/cosmos_framework/configs/base/defaults/llm_model.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Register LLM and DiT model configs alongside the existing ``mot_fsdp``.""" + +from hydra.core.config_store import ConfigStore + +from cosmos_framework.utils.lazy_config import LazyCall as L +from cosmos_framework.model.vfm.llm.dit.image_dit_model import DiTPretrainModel, DiTPretrainModelConfig +from cosmos_framework.model.vfm.llm.llm_pretrain_model import LLMPretrainModel, LLMPretrainModelConfig + +# ── FSDP config (production, multi-GPU) ────────────────────────────────────── + +LLM_FSDP_CONFIG = dict( + trainer=dict( + distributed_parallelism="fsdp", + ), + model=L(LLMPretrainModel)( + config=LLMPretrainModelConfig(), + _recursive_=False, + ), +) + +# ── DDP config (debug, single-node) ───────────────────────────────────────── + +LLM_DDP_CONFIG = dict( + trainer=dict( + distributed_parallelism="ddp", + ), + model=L(LLMPretrainModel)( + config=LLMPretrainModelConfig(), + _recursive_=False, + ), +) + +# ── Image DiT configs ─────────────────────────────────────────────────────── + +DIT_FSDP_CONFIG = dict( + trainer=dict( + distributed_parallelism="fsdp", + ), + model=L(DiTPretrainModel)( + config=DiTPretrainModelConfig(), + _recursive_=False, + ), +) + +DIT_DDP_CONFIG = dict( + trainer=dict( + distributed_parallelism="ddp", + ), + model=L(DiTPretrainModel)( + config=DiTPretrainModelConfig(), + _recursive_=False, + ), +) + + +def register_llm_model(): + cs = 
ConfigStore.instance() + cs.store(group="model", package="_global_", name="llm_fsdp", node=LLM_FSDP_CONFIG) + cs.store(group="model", package="_global_", name="llm_ddp", node=LLM_DDP_CONFIG) + cs.store(group="model", package="_global_", name="dit_fsdp", node=DIT_FSDP_CONFIG) + cs.store(group="model", package="_global_", name="dit_ddp", node=DIT_DDP_CONFIG) diff --git a/cosmos_framework/configs/base/defaults/model_config.py b/cosmos_framework/configs/base/defaults/model_config.py index f7e7c8d..0695aab 100644 --- a/cosmos_framework/configs/base/defaults/model_config.py +++ b/cosmos_framework/configs/base/defaults/model_config.py @@ -81,10 +81,6 @@ class RectifiedFlowTrainingConfig: loss_scale: float = 1.0 # Loss scale image_loss_scale: float | None = None # If set, overrides loss_scale for images sound_loss_scale: float | None = None # If set, overrides loss_scale for sound - use_high_sigma_strategy: bool = False # Whether to use high sigma strategy - high_sigma_ratio: float = 0.05 # Ratio of using high sigmas - high_sigma_timesteps_min: int = 995 # Minimum timestep for high sigma - high_sigma_timesteps_max: int = 1000 # Maximum timestep for high sigma use_discrete_rf: bool = False # Whether to use discrete formulation of rectified flow # user: please adjust this value according to loss_scale to balance the action loss with the video loss. @@ -93,21 +89,16 @@ class RectifiedFlowTrainingConfig: # Independent noise schedule for action. When False (default), action shares the sigma # sampled from the vision RF on every step — legacy behavior. When True, action samples - # its own sigma from `rectified_flow_action` using `shift_action` and - # `use_high_sigma_strategy_action`. Action always uses a shared scalar sigma per sample - # ([B,1]), independent of vision's DF mode. If action opts in to the high-sigma strategy, - # it reuses the global ratio / min / max. + # its own sigma from `rectified_flow_action` using `shift_action`. 
Action always uses a + # shared scalar sigma per sample ([B,1]), independent of vision's DF mode. independent_action_schedule: bool = False shift_action: int | None = None # must be int; None → inherit `shift` (which must also be int) - use_high_sigma_strategy_action: bool = False # Independent noise schedule for sound. When False (default), sound shares the vision # sigma schedule, reindexed to the dense audio-bearing subset. When True, sound samples - # its own scalar sigma per sample ([B,1]) from `rectified_flow_sound` using `shift_sound` - # and `use_high_sigma_strategy_sound`. + # its own scalar sigma per sample ([B,1]) from `rectified_flow_sound` using `shift_sound`. independent_sound_schedule: bool = False shift_sound: int | None = None # must be int; None → inherit `shift` (which must also be int) - use_high_sigma_strategy_sound: bool = False # When True, per-instance flow-matching loss is normalized by the count of # active (noisy) elements rather than all elements — preserves sum/active_count @@ -204,9 +195,7 @@ class OmniMoTModelConfig: # Attention implementation for joint understanding + generation # Note "two_way" and "three_way" disallow and remove "End-of-Vision" or other text token in the generation tower. # "three_way" must only be used when introducing sparsity - joint_attn_implementation: str = ( - "two_way" # "two_way", "three_way" or "flex" (NOTICE: We are planning to remove "flex" soon) - ) + joint_attn_implementation: str = "two_way" # "two_way" or "three_way" # Per-layer NATTEN parameters # Must use "three_way" attention if used. 
diff --git a/cosmos_framework/configs/base/defaults/optimizer.py b/cosmos_framework/configs/base/defaults/optimizer.py index 245a5df..4e254e6 100644 --- a/cosmos_framework/configs/base/defaults/optimizer.py +++ b/cosmos_framework/configs/base/defaults/optimizer.py @@ -93,6 +93,23 @@ def register_schedulers(lambdacosine_kwargs: dict[str, Any]) -> None: **lambdacosine_kwargs, ), ) + # WSD (Warmup-Stable-Decay) scheduler for LLM pretraining + cs.store( + group="scheduler", + package="scheduler", + name="wsd", + node=L(build_lr_scheduler)( + optimizer=PLACEHOLDER, + lr_scheduler_type="wsd", + warm_up_steps=2000, + total_steps=50000, + decay_steps=5000, + decay_type="cosine", + f_start=0.01, + f_max=1.0, + f_min=0.1, + ), + ) def register_optimizer() -> None: diff --git a/cosmos_framework/configs/base/defaults/parallelism.py b/cosmos_framework/configs/base/defaults/parallelism.py index ffce654..59b3eff 100644 --- a/cosmos_framework/configs/base/defaults/parallelism.py +++ b/cosmos_framework/configs/base/defaults/parallelism.py @@ -3,13 +3,19 @@ """User-facing parallelism degrees shared by VFM and VLM trainers.""" +from typing import Literal + import attrs import torch +AttentionIOLayout = Literal["sequence_sharded", "replicated"] + # Canonical mapping from precision string (used in user-facing configs and # threaded through OmegaConf) to ``torch.dtype``. Consumed by sites that # need to translate ``precision`` / ``fsdp_master_dtype`` into concrete # torch dtypes (e.g. ``MixedPrecisionPolicy``, ``HFModel`` meta-init). + + PRECISION_TO_TORCH_DTYPE: dict[str, torch.dtype] = { "bfloat16": torch.bfloat16, "float16": torch.float16, @@ -31,6 +37,15 @@ class ParallelismConfig: # Number of ranks for context parallelism. context_parallel_shard_degree: int = 1 + # Tensor layout at the attention boundary when CP is enabled. 
Both + # layouts may run the attention kernel with head-sharded Q/K/V: + # ``sequence_sharded`` keeps surrounding projections/MLP sequence-sharded + # with Ulysses-style all-to-all into/out of attention, while + # ``replicated`` keeps current-frame hidden states replicated, slices + # local heads before attention, then reduces/gathers attention output back + # to replicated hidden states. + attention_io_layout: AttentionIOLayout = "sequence_sharded" + # Number of ranks for CFG parallelism. cfg_parallel_shard_degree: int = 1 diff --git a/cosmos_framework/configs/base/defaults/tokenizer.py b/cosmos_framework/configs/base/defaults/tokenizer.py index 526d579..530907d 100644 --- a/cosmos_framework/configs/base/defaults/tokenizer.py +++ b/cosmos_framework/configs/base/defaults/tokenizer.py @@ -8,10 +8,12 @@ from cosmos_framework.model.vfm.tokenizers.audio.avae import AVAEInterface from cosmos_framework.model.vfm.tokenizers.dc_ae.dc_ae_4x32x32 import DCAE4x32x32Interface from cosmos_framework.model.vfm.tokenizers.flux_vae_8x8 import FluxVAEInterface +from cosmos_framework.model.vfm.tokenizers.stable_diffusion_vae_8x8 import StableDiffusionVAEInterface from cosmos_framework.model.vfm.tokenizers.uniae.noncausal_4x16x16 import UniAEVAEInterface from cosmos_framework.model.vfm.tokenizers.wan2pt1_vae_4x8x8 import Wan2pt1VAEInterface from cosmos_framework.model.vfm.tokenizers.wan2pt2_vae_4x16x16 import Wan2pt2VAEInterface +PRETRAINED_TOKENIZER_SD_VAE_REPO = "stabilityai/sd-vae-ft-ema" PRETRAINED_TOKENIZER_WAN2PT1_VAE_PTH = "pretrained/tokenizers/video/wan2pt1/Wan2.1_VAE.pth" PRETRAINED_TOKENIZER_WAN2PT2_VAE_PTH = "pretrained/tokenizers/video/wan2pt2/Wan2.2_VAE.pth" PRETRAINED_TOKENIZER_FLUX_VAE_PTH = "pretrained/tokenizers/image/flux/ae.safetensors" @@ -24,8 +26,8 @@ # DCAE checkpoint paths PRETRAINED_TOKENIZER_DCAE_4X32X32_C64_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = 
"pretrained/tokenizers/video/cosmos/dcae4x32x32_c64_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" -PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" -PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2.pt" +PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_LCR_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr.pt" +PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_LCR_PTH = "pretrained/tokenizers/video/cosmos/dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr.pt" # AVAE (Audio VAE) checkpoint paths PRETRAINED_TOKENIZER_AVAE_PTH = "pretrained/tokenizers/audio/avae/model_unwrap.ckpt" @@ -49,6 +51,19 @@ causal=True, ) +StableDiffusionVAEConfig: LazyDict = L(StableDiffusionVAEInterface)( + # Stable Diffusion VAE used by the original DiT ImageNet setup. 
+ bucket_name="", + object_store_credential_path_pretrained=None, + vae_path=PRETRAINED_TOKENIZER_SD_VAE_REPO, + scaling_factor=0.18215, + sample_posterior=True, + dtype="float32", + chunk_duration=1, + spatial_compression_factor=8, + temporal_compression_factor=1, +) + Wan2pt1VAEConfig: LazyDict = L(Wan2pt1VAEInterface)( # 4x8x8 tokenizer bucket_name=PLACEHOLDER, @@ -80,26 +95,28 @@ causal=True, ) -DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( +DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2LCRConfig: LazyDict = L( DCAE4x32x32Interface )( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH, - model_name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", + vae_path=PRETRAINED_TOKENIZER_DCAE_4X32X32_C96_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_LCR_PTH, + model_name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) -DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config: LazyDict = L( +DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2LCRConfig: LazyDict = L( DCAE4x32x32Interface )( bucket_name=PLACEHOLDER, object_store_credential_path_pretrained=PLACEHOLDER, - vae_path=PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_PTH, - model_name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", + vae_path=PRETRAINED_TOKENIZER_DCAE_4X32X32_C128_T120_256P_FPS_ALL_ENCODER_CAUSAL_DECODER_CHUNKCAUSAL4_NOGAN_COSMOS_PAD_7_V0PT2_LCR_PTH, + 
model_name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr", spatial_compression_factor=32, temporal_compression_factor=4, + causal=True, ) @@ -166,9 +183,16 @@ ) -def register_tokenizer(): +def register_tokenizer() -> None: cs = ConfigStore.instance() + # Stable Diffusion image tokenizer + cs.store( + group="tokenizer", + package="model.config.tokenizer", + name="sd_vae_tokenizer", + node=StableDiffusionVAEConfig, + ) # Wan2pt1 and Wan2pt2 tokenizers cs.store(group="tokenizer", package="model.config.tokenizer", name="wan2pt1_tokenizer", node=Wan2pt1VAEConfig) cs.store(group="tokenizer", package="model.config.tokenizer", name="wan2pt2_tokenizer", node=Wan2pt2VAEConfig) @@ -191,18 +215,18 @@ def register_tokenizer(): cs.store( group="tokenizer", package="model.config.tokenizer", - name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", - node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, + name="dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr_tokenizer", + node=DCAE4x32x32C96T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2LCRConfig, ) cs.store( group="tokenizer", package="model.config.tokenizer", - name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_tokenizer", - node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2Config, + name="dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr_tokenizer", + node=DCAE4x32x32C128T120_256pFpsAllEncoderCausalDecoderChunkCausal4NoganCosmosPad7V0pt2LCRConfig, ) -def register_sound_tokenizer(): +def register_sound_tokenizer() -> None: """Register sound tokenizers in Hydra ConfigStore under model.config.sound_tokenizer.""" cs = ConfigStore.instance() cs.store( diff --git 
a/cosmos_framework/configs/base/experiment/sft/models/nano_model_config.py b/cosmos_framework/configs/base/experiment/sft/models/nano_model_config.py index 37f2689..7340737 100644 --- a/cosmos_framework/configs/base/experiment/sft/models/nano_model_config.py +++ b/cosmos_framework/configs/base/experiment/sft/models/nano_model_config.py @@ -85,9 +85,6 @@ ), rectified_flow_training_config=dict( action_loss_weight=10.0, - high_sigma_ratio=0.05, - high_sigma_timesteps_max=1000, - high_sigma_timesteps_min=995, image_loss_scale=1.0, independent_action_schedule=False, loss_scale=1.0, @@ -100,8 +97,6 @@ train_time_weight="uniform", use_discrete_rf=False, use_dynamic_shift=False, - use_high_sigma_strategy=False, - use_high_sigma_strategy_action=False, ), tokenizer=dict( bucket_name="", diff --git a/cosmos_framework/configs/base/experiment/sft/models/super_model_config.py b/cosmos_framework/configs/base/experiment/sft/models/super_model_config.py index 0400f8d..8c7278a 100644 --- a/cosmos_framework/configs/base/experiment/sft/models/super_model_config.py +++ b/cosmos_framework/configs/base/experiment/sft/models/super_model_config.py @@ -107,9 +107,6 @@ ), rectified_flow_training_config=dict( action_loss_weight=10.0, - high_sigma_ratio=0.05, - high_sigma_timesteps_max=1000, - high_sigma_timesteps_min=995, image_loss_scale=1.0, independent_action_schedule=False, loss_scale=1.0, @@ -122,8 +119,6 @@ train_time_weight="uniform", use_discrete_rf=False, use_dynamic_shift=False, - use_high_sigma_strategy=False, - use_high_sigma_strategy_action=False, ), tokenizer=dict( bucket_name="", diff --git a/cosmos_framework/configs/base/vlm/defaults/callbacks.py b/cosmos_framework/configs/base/vlm/defaults/callbacks.py index 3392205..bb1cd4b 100644 --- a/cosmos_framework/configs/base/vlm/defaults/callbacks.py +++ b/cosmos_framework/configs/base/vlm/defaults/callbacks.py @@ -12,6 +12,7 @@ from cosmos_framework.utils.lazy_config import LazyCall as L from cosmos_framework.utils.callback import 
LowPrecisionCallback, WandBCallback from cosmos_framework.callbacks.dataloader_state import DataLoaderStateCallback + from cosmos_framework.callbacks.grad_clip import GradClip from cosmos_framework.callbacks.hf_export import HFExportCallback from cosmos_framework.callbacks.iter_speed import IterSpeed diff --git a/cosmos_framework/data/vfm/action/action_processing.py b/cosmos_framework/data/vfm/action/action_processing.py index 1a09170..2d445da 100644 --- a/cosmos_framework/data/vfm/action/action_processing.py +++ b/cosmos_framework/data/vfm/action/action_processing.py @@ -91,6 +91,7 @@ def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, n def resolve_action_normalization( method: ActionNormalizationMethod, stats: dict[str, torch.Tensor], + apply_forward_clamp: bool = False, ) -> ActionAffineNormalization: """Resolve configured action stats into affine forward/inverse parameters.""" if method == "meanstd": @@ -109,10 +110,17 @@ def resolve_action_normalization( offset = (hi + lo) / 2.0 # [D] scale = (hi - lo).clamp(min=1e-8) / 2.0 # [D] + + if apply_forward_clamp: + # Ideally this hardcode should be removed, but for now we keep it so we can be aligned with mid-training checkpoints. 
+ forward_clamp = (-1.0, 1.0) + else: + forward_clamp = None + return ActionAffineNormalization( offset=offset, scale=scale, - forward_clamp=(-1.0, 1.0), + forward_clamp=forward_clamp, ) diff --git a/cosmos_framework/data/vfm/action/transforms.py b/cosmos_framework/data/vfm/action/transforms.py index 3462f1e..96f4f0b 100644 --- a/cosmos_framework/data/vfm/action/transforms.py +++ b/cosmos_framework/data/vfm/action/transforms.py @@ -30,8 +30,8 @@ from cosmos_framework.data.vfm.augmentors.idle_frames_text_info import IdleFramesTextInfo from cosmos_framework.data.vfm.augmentors.resolution_text_info import ResolutionTextInfo from cosmos_framework.data.vfm.augmentors.text_tokenizer import TextTokenizerTransform -from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution diff --git a/cosmos_framework/data/vfm/augmentor_provider.py b/cosmos_framework/data/vfm/augmentor_provider.py index 3e3d785..fa859e4 100644 --- a/cosmos_framework/data/vfm/augmentor_provider.py +++ b/cosmos_framework/data/vfm/augmentor_provider.py @@ -10,6 +10,7 @@ import cosmos_framework.data.vfm.augmentors.append_fps_frames_for_image as append_fps_frames_for_image import cosmos_framework.data.vfm.augmentors.audio_caption as audio_caption import cosmos_framework.data.vfm.augmentors.caption_filter as caption_filter +import cosmos_framework.data.vfm.augmentors.cropping as cosmos_cropping import cosmos_framework.data.vfm.augmentors.duration_fps_text_timestamps as duration_fps_text_timestamps import cosmos_framework.data.vfm.augmentors.image_resolution_filter as image_resolution_filter import cosmos_framework.data.vfm.augmentors.merge_datadict as merge_datadict @@ -25,6 +26,9 @@ from cosmos_framework.data.vfm.augmentors import sequence_plan from cosmos_framework.data.vfm.utils import 
IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO +# UniAE requires spatial dimensions divisible by (spatial_compression * patch_spatial) = 16 * 2 = 32. +UNIAE_SPATIAL_MULTIPLE = 32 + AUGMENTOR_OPTIONS = {} CAMERA_MOVEMENT_PHRASES = [ @@ -617,9 +621,20 @@ def get_video_augmentor_v3( input_keys=["video"], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "reflection_padding": L(padding.ReflectionPadding)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + **( + { + "reflection_padding": L(padding.ReflectionPadding)( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ) + } + if causal_vae + else { + "crop_to_multiple": L(cosmos_cropping.CropToMultiple)( + input_keys=["video"], + args={"multiple": UNIAE_SPATIAL_MULTIPLE}, + ) + } ), "text_transform": L(text_transforms_for_video.TextTransformForVideoWithFullFrames)( input_keys=["metas", "ai_caption", "sequence_plan"], @@ -781,6 +796,8 @@ def get_video_augmentor_v3_json_caption( use_dynamic_fps: bool = False, max_stride: int = 3, min_stride: int = 1, + min_fps: float = 10.0, + max_fps: float = 60.0, use_system_prompt: bool = False, resize_on_read: bool = False, dataset_resolution_type: str = "all", @@ -842,7 +859,6 @@ def get_video_augmentor_v3_json_caption( uniae_pad_frames = kwargs.get("uniae_pad_frames", None) uniae_chunk_frames = kwargs.get("uniae_chunk_frames", None) - print("Running video_augmentor_v3_json_caption...") augmentors = { # Caption parsing runs BEFORE video parsing so that VideoParsingChunkedFrames can # decode only the frames belonging to a randomly sampled caption chunk. 
@@ -863,6 +879,8 @@ def get_video_augmentor_v3_json_caption( "use_dynamic_fps": use_dynamic_fps, "max_stride": max_stride, "min_stride": min_stride, + "min_fps": min_fps, + "max_fps": max_fps, "seek_mode": "exact", "dataset_resolution_type": dataset_resolution_type, "resolution": resolution, @@ -908,9 +926,20 @@ def get_video_augmentor_v3_json_caption( input_keys=["video"], args={"size": VIDEO_RES_SIZE_INFO[resolution]}, ), - "reflection_padding": L(padding.ReflectionPadding)( - input_keys=["video"], - args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + **( + { + "reflection_padding": L(padding.ReflectionPadding)( + input_keys=["video"], + args={"size": VIDEO_RES_SIZE_INFO[resolution]}, + ) + } + if causal_vae + else { + "crop_to_multiple": L(cosmos_cropping.CropToMultiple)( + input_keys=["video"], + args={"multiple": UNIAE_SPATIAL_MULTIPLE}, + ) + } ), # Duration/FPS timestamp augmentor - appends metadata like "The video is 2.5 seconds long and is of 24 FPS." # To customize the template or separator, add them to the args dict below: diff --git a/cosmos_framework/data/vfm/augmentors/cropping.py b/cosmos_framework/data/vfm/augmentors/cropping.py index ce65744..12b826c 100644 --- a/cosmos_framework/data/vfm/augmentors/cropping.py +++ b/cosmos_framework/data/vfm/augmentors/cropping.py @@ -73,7 +73,9 @@ def __call__(self, data_dict: dict) -> dict: # log.info(f"Data cropped from ({h}, {w}) to ({new_h}, {new_w})") data_dict[key] = transforms_F.crop(data, top=top, left=left, height=new_h, width=new_w) - # Store final dimensions for downstream use (e.g., resolution text info) + # Store final dimensions for downstream use (e.g., ResolutionTextInfo) + # Use the same image_size format as ReflectionPadding: [target_h, target_w, orig_h, orig_w] + data_dict["image_size"] = torch.tensor([new_h, new_w, h, w], dtype=torch.float) data_dict["final_height"] = new_h data_dict["final_width"] = new_w diff --git a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py 
b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py index d38fae4..8b094ae 100644 --- a/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py +++ b/cosmos_framework/data/vfm/augmentors/text_transforms_for_image.py @@ -9,6 +9,8 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log +# COSMOS-RELEASE-END-IGNORE + # For the qwen captions, we have 3 variants: short, medium, long # In addition, for synthetic data, we create prompt embeddings as well. # There is quite a bit of entropy in the way prompt data is saved. diff --git a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py index a18e8fd..c94b8af 100644 --- a/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py +++ b/cosmos_framework/data/vfm/augmentors/transfer_control_transform.py @@ -29,8 +29,8 @@ from cosmos_framework.data.imaginaire.webdataset.augmentors.augmentor import Augmentor from cosmos_framework.utils import log from cosmos_framework.data.vfm.augmentors.transfer_control_input import AddControlInputComb -from cosmos_framework.data.vfm.sequence_packing import SequencePlan from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO +from cosmos_framework.data.vfm.sequence_packing import SequencePlan class SampleResolution(Augmentor): diff --git a/cosmos_framework/data/vfm/augmentors/video_parsing.py b/cosmos_framework/data/vfm/augmentors/video_parsing.py index 25a5580..1304f6c 100644 --- a/cosmos_framework/data/vfm/augmentors/video_parsing.py +++ b/cosmos_framework/data/vfm/augmentors/video_parsing.py @@ -442,8 +442,9 @@ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: O self.resolution_tier = _DATASET_RESOLUTION_TIER.get(self.dataset_resolution_type) # VAE temporal alignment mode. - # causal_vae=True (default): align to 1+4N (causal VAE, e.g. 
Wan 2.2) - # causal_vae=False: align to 4N (non-causal VAE, e.g. UniAE) + # causal_vae=True (default): align to 1+4N (causal VAE, e.g. Wan 2.2) + # causal_vae=False: align to 1+effective_chunk_frames*N (UniAE with chunk structure) + # or 4N (generic non-causal VAE) self.causal_vae = args.get("causal_vae", True) self.target_resolution_key = None if args.get("resolution") is None else str(args["resolution"]) self.uniae_pad_frames = None if args.get("uniae_pad_frames") is None else int(args["uniae_pad_frames"]) diff --git a/cosmos_framework/data/vfm/joint_dataloader.py b/cosmos_framework/data/vfm/joint_dataloader.py index ba84ceb..6ad83eb 100644 --- a/cosmos_framework/data/vfm/joint_dataloader.py +++ b/cosmos_framework/data/vfm/joint_dataloader.py @@ -389,13 +389,15 @@ def _compute_num_tokens_per_sample(self, data_batch: dict) -> int: # vae temporal downsampling factor # Action tokens have 1 token per time step (no spatial dimension) - text_token_ids = data_batch["text_token_ids"] - if isinstance(text_token_ids, list): - num_text_tokens = text_token_ids[0].shape[0] - else: - num_text_tokens = text_token_ids.shape[1] - - num_tokens = num_text_tokens + 1 + has_text_tokens = "text_token_ids" in data_batch + num_tokens = 0 + if has_text_tokens: + text_token_ids = data_batch["text_token_ids"] + if isinstance(text_token_ids, list): + num_text_tokens = text_token_ids[0].shape[0] + else: + num_text_tokens = text_token_ids.shape[1] + num_tokens += num_text_tokens + 1 # Vision part is_image_batch = "images" in data_batch @@ -415,7 +417,9 @@ def _compute_num_tokens_per_sample(self, data_batch: dict) -> int: patch_w_shape = math.ceil(latent_w_shape / self.patch_spatial) latent_t_shape = self._compute_vision_latent_t_shape(T, H, W) - num_vision_tokens = patch_h_shape * patch_w_shape * latent_t_shape + 2 + num_vision_tokens = patch_h_shape * patch_w_shape * latent_t_shape + if has_text_tokens: + num_vision_tokens += 2 num_tokens += num_vision_tokens # Action part: each action time 
step is 1 token. diff --git a/cosmos_framework/data/vfm/local_datasets/sft_dataset.py b/cosmos_framework/data/vfm/local_datasets/sft_dataset.py index c8f0858..38b965e 100644 --- a/cosmos_framework/data/vfm/local_datasets/sft_dataset.py +++ b/cosmos_framework/data/vfm/local_datasets/sft_dataset.py @@ -24,7 +24,8 @@ get_video_metadata, parse_s3_url, ) -from cosmos_framework.data.vfm.sequence_packing import SequencePlan, add_special_tokens +from cosmos_framework.data.vfm.sequence_packing import SequencePlan +from cosmos_framework.data.vfm.sequence_packing.modalities import add_special_tokens from cosmos_framework.data.vfm.utils import VIDEO_RES_SIZE_INFO from cosmos_framework.inference.structured_caption import CAPTION_JSON_KEY, caption_json_to_prompt from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import tokenize_caption diff --git a/cosmos_framework/data/vfm/sequence_packing.py b/cosmos_framework/data/vfm/sequence_packing.py deleted file mode 100644 index 1209a2d..0000000 --- a/cosmos_framework/data/vfm/sequence_packing.py +++ /dev/null @@ -1,3065 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: OpenMDW-1.1 - -"""Functions for implementing sequence packing with flexible attention modes. - -This module provides utilities for packing text and image sequences together -with support for different attention patterns (causal, full, noise). - -Key Components: ---------------- -1. Attention Mask Creation: - - create_sparse_mask(): Creates sparse masks for flex attention - - prepare_attention_mask_per_sample(): Creates dense attention masks - -2. Position ID Generation: - - get_flattened_position_ids_extrapolate(): Extrapolation-based position encoding - - get_flattened_position_ids_interpolate(): Interpolation-based position encoding - -3. Tokenizer Setup: - - add_special_tokens(): Adds image boundary tokens to tokenizer - -4. 
Sequence Packing: - - pack_input_sequence(): Main function for packing text and image sequences - - Helper functions: _pack_text_tokens(), _pack_image_tokens(), _finalize_packed_data() - -Sequence Format: ---------------- -Each sample consists of alternating text and image sections: - [text_tokens] [image_tokens] ... - -Attention Modes: ---------------- -- 'causal': Standard causal/autoregressive attention for text -- 'full': Bidirectional attention for images -- 'noise': Special mode for noise conditioning -""" - -import math -from collections.abc import Mapping, Sequence -from dataclasses import dataclass, field -from typing import Any, Dict, List, Tuple - -import torch -from torch.nn.attention.flex_attention import and_masks, or_masks - -from cosmos_framework.model.attention.checks import check_valid_tuple_or_element -from cosmos_framework.model.attention.varlen import generate_multi_dim_varlen_parameters -from cosmos_framework.utils import log -from cosmos_framework.model.vfm.mot.unified_3dmrope_utils import ( - get_3d_mrope_ids_text_tokens, - get_3d_mrope_ids_vae_tokens, -) -from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean -from cosmos_framework.model.vfm.tokenizers.tokenization_qwen2 import Qwen2Tokenizer - -MAX_CAUSAL_LEN_IMAGE_BATCH = 0 -MAX_FULL_LEN_IMAGE_BATCH = 0 -MAX_CAUSAL_LEN_VIDEO_BATCH = 0 -MAX_FULL_LEN_VIDEO_BATCH = 0 - - -# ============================================================================ -# Attention mask creation -# ============================================================================ - - -def create_sparse_mask(document_lens, split_lens, attn_modes, device): - """Create a sparse attention mask combining multiple attention patterns. 
- - Args: - document_lens: List of document lengths - split_lens: List of split lengths within documents - attn_modes: List of attention modes ('causal', 'full', 'noise') for each split - device: Device to place tensors on - - Returns: - Combined mask using flex attention API - """ - - # Build sequence ID tensors for tracking full/noise attention regions - full_and_noise_seq_ids = [] - noise_seq_ids = [] - - for seq_idx, (length, attn_mode) in enumerate(zip(split_lens, attn_modes)): - # Assign sequence ID for full/noise regions, -1 for causal regions - seq_id = seq_idx if attn_mode in ["full", "noise"] else -1 - full_and_noise_seq_ids.extend([seq_id] * length) - - # Assign sequence ID only for noise regions - noise_seq_id = seq_idx if attn_mode == "noise" else -1 - noise_seq_ids.extend([noise_seq_id] * length) - - full_and_noise_seq_id = torch.tensor(full_and_noise_seq_ids, device=device) # [seq_len] - noise_seq_id = torch.tensor(noise_seq_ids, device=device) # [seq_len] - document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device) # [seq_len] - - # Define component mask functions - def causal_mask(b, h, q_idx, kv_idx): - """Standard causal attention: query can only attend to prior keys.""" - return q_idx >= kv_idx - - def full_and_noise_mask(b, h, q_idx, kv_idx): - """Allow attention within same full/noise sequence.""" - return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0) - - def remove_noise_mask(b, h, q_idx, kv_idx): - """Prevent attending to noise tokens from different sequences.""" - return ~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])) - - def sample_mask(b, h, q_idx, kv_idx): - """Ensure attention stays within same document/sample.""" - return document_id[q_idx] == document_id[kv_idx] - - # Combine all masks: (causal OR full_and_noise) AND remove_noise AND sample - return and_masks(or_masks(causal_mask, full_and_noise_mask), 
remove_noise_mask, sample_mask) - - -def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"): - """Prepare dense attention mask for a single sample with multiple splits. - - Args: - split_lens: List of integers indicating length of each split within the sample - attn_modes: List of attention modes for each split ('causal', 'full', or 'noise') - device: Device to place the attention mask tensor on - - Returns: - Attention mask tensor of shape (sample_len, sample_len) with -inf for masked positions - """ - sample_len = sum(split_lens) - attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device) # [sample_len,sample_len] - - # First pass: Set up basic attention patterns for each split - current_pos = 0 - for split_len, attn_mode in zip(split_lens, attn_modes): - assert attn_mode in ["causal", "full", "noise"], f"Invalid attention mode: {attn_mode}" - - split_start = current_pos - split_end = current_pos + split_len - - if attn_mode == "causal": - # Causal: lower triangular within split + full attention to previous splits - attention_mask[split_start:split_end, split_start:split_end] = torch.ones( - (split_len, split_len), device=device - ).tril() # [split_len,split_len] - attention_mask[split_start:split_end, :split_start] = 1 - else: # "full" or "noise" - # Full attention within split and to previous splits - attention_mask[split_start:split_end, split_start:split_end] = torch.ones( - (split_len, split_len), device=device - ) # [split_len,split_len] - attention_mask[split_start:split_end, :split_start] = 1 - - current_pos += split_len - - # Second pass: Handle noise mode - mask out noise columns except within same split - current_pos = 0 - for split_len, attn_mode in zip(split_lens, attn_modes): - if attn_mode == "noise": - split_start = current_pos - split_end = current_pos + split_len - - # Zero out the entire column for noise tokens - attention_mask[:, split_start:split_end] = 0 - # But allow self-attention within 
def add_special_tokens(tokenizer):
    """Make sure the vision boundary tokens exist in ``tokenizer``.

    ``<|vision_start|>`` and ``<|vision_end|>`` are passed to
    ``tokenizer.add_tokens`` when they are not already listed among the values
    of ``tokenizer.special_tokens_map``.

    Args:
        tokenizer: Object exposing ``special_tokens_map``, ``add_tokens`` and
            ``convert_tokens_to_ids``. It is modified in place.

    Returns:
        Tuple of (the same tokenizer, dict mapping ``"start_of_generation"`` and
        ``"end_of_generation"`` to the IDs of the two boundary tokens).
    """
    vision_start, vision_end = "<|vision_start|>", "<|vision_end|>"

    # Flatten special_tokens_map: each value is a single token or a list of tokens.
    known_tokens = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known_tokens.append(entry)
        elif isinstance(entry, list):
            known_tokens += entry

    # Register only the boundary tokens that are not known yet.
    missing = [token for token in (vision_start, vision_end) if token not in known_tokens]
    if missing:
        tokenizer.add_tokens(missing)

    start_id = tokenizer.convert_tokens_to_ids(vision_start)
    end_id = tokenizer.convert_tokens_to_ids(vision_end)
    return tokenizer, {"start_of_generation": start_id, "end_of_generation": end_id}
- - This dataclass serves dual purposes: - 1. During packing: Acts as a builder, accumulating data in lists - 2. After finalize(): Holds finalized tensors ready for model consumption - - Attributes: - sequence_indexes: Indices in the packed sequence where this modality's tokens appear. - List during building, Tensor after finalize(). - timesteps: Diffusion timesteps for each noised token. - List during building, Tensor after finalize(). - mse_loss_indexes: Indices where MSE loss should be computed (noised tokens only). - List during building, Tensor after finalize(). - token_shapes: Shape metadata for each sample's tokens. - For vision: list of (T, H, W) tuples. - For action: list of (T,) tuples. - tokens: The actual latent tokens. List during build, Tensor after finalize(). - condition_mask: Mask indicating clean frames (1=clean, 0=noised). Only after finalize(). - noisy_frame_indexes: Indices of noised frames. Constructed from condition_mask during - sequence packing to reduce GPU->CPU synchronization later. Only after finalize(). - domain_id: Domain ID for multi-domain training. Only after finalize(). NOTE: only used for action modality. - raw_action_dim: Raw action dimension. Only after finalize(). NOTE: only used for action modality. 
- """ - - # Core tracking (list during build, tensor after finalize) - sequence_indexes: list[int] | torch.Tensor = field(default_factory=list) - timesteps: list[float] | torch.Tensor = field(default_factory=list) - mse_loss_indexes: list[int] | torch.Tensor = field(default_factory=list) - # list[tuple[int,int,int]] for vision, list[tuple[int]] for action, list[tuple[int,int,int]] for sound - token_shapes: list = field(default_factory=list) - - # Populated during finalization (from GenerationDataClean / noise path) - tokens: list[torch.Tensor] = field(default_factory=list) - condition_mask: list[torch.Tensor] = field(default_factory=list) - noisy_frame_indexes: list[torch.Tensor] = field(default_factory=list) - domain_id: list[torch.Tensor] = field(default_factory=list) - raw_action_dim: list[torch.Tensor | None] | None = field(default_factory=list) - - def to_cuda(self) -> None: - """Move all tensor fields to CUDA in-place.""" - if isinstance(self.sequence_indexes, torch.Tensor): - self.sequence_indexes = self.sequence_indexes.cuda() - if isinstance(self.timesteps, torch.Tensor): - self.timesteps = self.timesteps.cuda() - if isinstance(self.mse_loss_indexes, torch.Tensor): - self.mse_loss_indexes = self.mse_loss_indexes.cuda() - self.tokens = [token.cuda() for token in self.tokens] - self.condition_mask = [cm.cuda() for cm in self.condition_mask] - self.noisy_frame_indexes = [ni.cuda() for ni in self.noisy_frame_indexes] - self.domain_id = [d.cuda() for d in self.domain_id] - # raw_action_dim is optional (e.g., when action-channel masking is disabled). - if self.raw_action_dim is not None: - self.raw_action_dim = [d.cuda() if d is not None else None for d in self.raw_action_dim] - - -@dataclass -class PackedSequence: - """Unified sequence container - works as builder during packing and final output. 
- - This dataclass replaces the old SequenceStatus + PackedSequence pattern: - - Build phase: Accumulate data using lists, modalities use ModalityData builders - - After finalize(): Ready for model consumption with tensors - - Attributes: - # Sequence structure - sample_lens: Length of each sample in the packed sequence. - split_lens: Length of each split (text/vision/action sections). - attn_modes: Attention mode for each split ('causal', 'full'). - is_image_batch: Whether this batch contains images (vs videos). - sequence_length: Total length of packed sequence. Computed during finalize(). - - # Build-time tracking (not used after finalize) - curr: Current position in the packed sequence during building. - - # Text modality (list during build, tensor after finalize) - text_ids: All text token IDs (including special tokens). - text_indexes: Indices where text tokens appear in sequence. - position_ids: RoPE position IDs for all tokens. - - # Loss computation - Cross Entropy (text) - label_ids: Label IDs for cross-entropy loss. - ce_loss_indexes: Indices for computing cross-entropy loss. - ce_loss_weights: Weights for cross-entropy loss. - - # Generation modalities - named fields for type safety - vision: Vision modality data (images/videos). None if no vision in batch. - action: Action modality data (robotics). None if no actions in batch. - sound: Sound modality data (audio). None if no sound in batch. 
- """ - - # Sequence structure - sample_lens: list[int] = field(default_factory=list) - split_lens: list[int] = field(default_factory=list) - attn_modes: list[str] = field(default_factory=list) - is_image_batch: bool = False - sequence_length: int = 0 - - # Build-time tracking (used during packing, not after finalize) - curr: int = 0 - - # Text modality (list during build, tensor after finalize) - text_ids: list[int] | torch.Tensor = field(default_factory=list) - text_indexes: list[int] | torch.Tensor = field(default_factory=list) - position_ids: list[int] | torch.Tensor = field(default_factory=list) - - # Loss computation - Cross Entropy (text) - label_ids: list[int] | torch.Tensor | None = field(default_factory=list) - ce_loss_indexes: list[int] | torch.Tensor | None = field(default_factory=list) - ce_loss_weights: list[float] | torch.Tensor | None = field(default_factory=list) - - # Build-time mRoPE tracking (used during packing, not after finalize) - # When _use_mrope=True, position_ids accumulates (3, N) tensors instead of ints, - # and finalize() produces a (3, total_seq_len) tensor instead of (total_seq_len,). - _use_mrope: bool = False - # Running temporal index for mRoPE position ID generation within a single sample. - # Reset to 0 at the start of each sample, then advanced by text and vision helpers - # as segments are packed. Action reuses the pre-vision snapshot (parallel temporal - # range) without advancing it. Float when FPS modulation is enabled. - # E.g. offset=0 -> text(4 tokens) -> offset=4 -> vision(3 frames) -> offset=7. - _mrope_temporal_offset: int | float = 0 - _mrope_reset_spatial: bool = True - - # Temporal causal: whether supertoken 0's action slot contains null tokens. - # True for all training calls and AR frame 0; False for AR frame N>0 (real actions). - # Used by three_way_attention to zero out V for null action tokens (inline when attention_meta.null_action_supertokens=True). 
- null_action_supertokens: bool = False - - # Temporal causal: number of action tokens prefixing each vision supertoken. - # Equals temporal_compression_factor when actions are packed inline; 0 when - # action_gen=False or for non-temporal-causal layouts. Single source of truth - # for downstream attention/KV-cache code (per-supertoken layout is - # num_action_tokens_per_supertoken + H_p * W_p). - num_action_tokens_per_supertoken: int = 0 - - # Generation modalities - NAMED FIELDS for type safety - vision: ModalityData | None = None - action: ModalityData | None = None - sound: ModalityData | None = None - - def finalize( - self, - gen_data_clean: GenerationDataClean, - ) -> "PackedSequence": - """Convert all lists to tensors and compute derived values. - - Args: - gen_data_clean: GenerationDataClean for metadata (e.g., action domain IDs). - - Returns: - New PackedSequence instance with tensors instead of lists. - """ - # Compute sequence length - sequence_length = sum(self.sample_lens) - sample_lens = self.sample_lens.copy() - split_lens = self.split_lens.copy() - attn_modes = self.attn_modes.copy() - - # Prepare loss-related tensors (cross-entropy) - label_ids: torch.Tensor | None = None - ce_loss_indexes: torch.Tensor | None = None - ce_loss_weights: torch.Tensor | None = None - if self.label_ids and len(self.label_ids) > 0: - label_ids = torch.tensor(self.label_ids) # [N_ce_tokens] - ce_loss_indexes = torch.tensor(self.ce_loss_indexes) # [N_ce_tokens] - ce_loss_weights = torch.tensor(self.ce_loss_weights) # [N_ce_tokens] - - # The condition_mask and noisy_frame_indexes are kept as lists to support variable shapes. 
- - # Finalize vision modality - vision: ModalityData | None = None - if self.vision is not None and len(self.vision.sequence_indexes) > 0: - vision = ModalityData( - sequence_indexes=torch.tensor(self.vision.sequence_indexes, dtype=torch.long), # [N_vision_tokens] - timesteps=torch.tensor(self.vision.timesteps), # [N_vision_noisy_tokens] - mse_loss_indexes=torch.tensor( - self.vision.mse_loss_indexes, dtype=torch.long - ), # [N_vision_noisy_tokens] - token_shapes=list(self.vision.token_shapes), - tokens=self.vision.tokens, - condition_mask=list(self.vision.condition_mask), - noisy_frame_indexes=list(self.vision.noisy_frame_indexes), - ) - - # Finalize action modality - action: ModalityData | None = None - if self.action is not None and len(self.action.sequence_indexes) > 0: - action = ModalityData( - sequence_indexes=torch.tensor(self.action.sequence_indexes, dtype=torch.long), # [N_action_tokens] - timesteps=torch.tensor(self.action.timesteps), # [N_action_noisy_tokens] - mse_loss_indexes=torch.tensor( - self.action.mse_loss_indexes, dtype=torch.long - ), # [N_action_noisy_tokens] - token_shapes=list(self.action.token_shapes), - tokens=self.action.tokens, - condition_mask=list(self.action.condition_mask), # Keep as list to support variable shapes - noisy_frame_indexes=list(self.action.noisy_frame_indexes), - domain_id=( - gen_data_clean.action_domain_id - if gen_data_clean.action_domain_id is not None - else [torch.zeros(1, dtype=torch.long)] * len(self.action.token_shapes) - ), - raw_action_dim=gen_data_clean.raw_action_dim, - ) - - # Finalize sound modality (placeholder for future) - sound: ModalityData | None = None - if self.sound is not None and len(self.sound.sequence_indexes) > 0: - sound = ModalityData( - sequence_indexes=torch.tensor(self.sound.sequence_indexes, dtype=torch.long), # [N_sound_tokens] - timesteps=torch.tensor(self.sound.timesteps), # [N_sound_noisy_tokens] - mse_loss_indexes=torch.tensor(self.sound.mse_loss_indexes, dtype=torch.long), # 
[N_sound_noisy_tokens] - token_shapes=list(self.sound.token_shapes), - tokens=self.sound.tokens, - condition_mask=list(self.sound.condition_mask), - noisy_frame_indexes=list(self.sound.noisy_frame_indexes), - ) - - # Finalize position IDs: 3D mRoPE (3, seq_len) or 1D RoPE (seq_len,) - if self._use_mrope and len(self.position_ids) > 0 and isinstance(self.position_ids[0], torch.Tensor): - mrope_tensors: list[torch.Tensor] = self.position_ids # type: ignore[assignment] - position_ids = torch.cat(mrope_tensors, dim=1) # [3,actual_seq_len] - else: # Original 1D RoPE from Bagel, where all the media tokens share the same 1D position ID - position_ids = torch.tensor(self.position_ids) # [seq_len] - - return PackedSequence( - # Sequence structure - sequence_length=sequence_length, - sample_lens=sample_lens, - split_lens=split_lens, - attn_modes=attn_modes, - is_image_batch=gen_data_clean.is_image_batch, - # Text modality (converted to tensors) - text_ids=torch.tensor(self.text_ids, dtype=torch.long), # [N_text_tokens] - text_indexes=torch.tensor(self.text_indexes, dtype=torch.long), # [N_text_tokens] - position_ids=position_ids, # [seq_len] or [3,seq_len] - # Loss computation - Cross Entropy - label_ids=label_ids, - ce_loss_indexes=ce_loss_indexes, - ce_loss_weights=ce_loss_weights, - # Generation modalities - vision=vision, - action=action, - sound=sound, - # Temporal causal - null_action_supertokens=self.null_action_supertokens, - num_action_tokens_per_supertoken=self.num_action_tokens_per_supertoken, - ) - - def to_cuda(self) -> None: - """Move all tensor fields to CUDA in-place.""" - if isinstance(self.text_ids, torch.Tensor): - self.text_ids = self.text_ids.cuda() - if isinstance(self.text_indexes, torch.Tensor): - self.text_indexes = self.text_indexes.cuda() - if isinstance(self.position_ids, torch.Tensor): - self.position_ids = self.position_ids.cuda() - if isinstance(self.label_ids, torch.Tensor): - self.label_ids = self.label_ids.cuda() - if 
@dataclass
class SequencePlan:
    """Lightweight description of the modalities present in one sample.

    The plan records which modalities (text, vision, action, sound) a dataset
    sample contains and which of their frames/steps act as clean conditioning.
    It holds only flags and index lists, no tensor data.

    Attributes:
        has_text: Whether text/caption tokens are present for this sample
            (text-conditioned generation such as text-to-image/video).
        has_vision: Whether vision input (image or video latents) is present.
        condition_frame_indexes_vision: Indexes of latent vision frames that are
            clean/conditioning. ``[]`` means every frame is noised/supervised;
            listing every frame means all frames are clean (no MSE supervision).
            For multi-item samples this applies to each vision item
            individually; the number of items per sample is tracked by
            ``GenerationDataClean.num_vision_items_per_sample``.
        has_action: Whether action input is present (robotics/embodied tasks).
        condition_frame_indexes_action: Indexes of action steps that are
            clean/conditioning, with the same ``[]`` / all-steps convention.
    """

    # -- understanding (text conditioning) --
    has_text: bool

    # -- vision modality --
    has_vision: bool = False
    condition_frame_indexes_vision: list[int] = field(default_factory=list)
    # When True, every vision item of the sample shares one temporal mRoPE grid
    # (controlnet-style transfer: target frame i is aligned with control frame i);
    # each item gets the same temporal_offset and spatial reset behavior is
    # unchanged. Requires num_vision_items_per_sample > 1, equal latent_t and
    # equal fps across items. The default (False) keeps single-clip and
    # image-editing semantics, where items are distinct time states.
    share_vision_temporal_positions: bool = False

    # -- action modality --
    has_action: bool = False
    condition_frame_indexes_action: list[int] = field(default_factory=list)
    action_start_frame_offset: int = 1

    # -- sound modality --
    has_sound: bool = False
    condition_frame_indexes_sound: list[int] = field(default_factory=list)

    def as_dict(self) -> dict:
        """Return the presence flags and conditioning indexes as a plain dict.

        The index lists are returned by reference, not copied, and
        ``action_start_frame_offset`` is not part of the exported mapping.
        """
        exported = (
            "has_text",
            "has_vision",
            "has_action",
            "has_sound",
            "condition_frame_indexes_vision",
            "condition_frame_indexes_action",
            "condition_frame_indexes_sound",
            "share_vision_temporal_positions",
        )
        return {name: getattr(self, name) for name in exported}
def compute_text_split_length(
    num_caption_tokens: int,
    special_tokens: Dict[str, int],
    has_generation: bool = True,
) -> int:
    """Compute the total text split length without mutating any state.

    This is the number of token positions occupied by the text split in a
    packed sequence: caption tokens + optional BOS + EOS + optional BOV.

    Args:
        num_caption_tokens: Number of raw caption token IDs (before special tokens).
        special_tokens: Dictionary of special token IDs (checked for ``"bos_token_id"``).
        has_generation: Whether a start-of-generation (BOV) token follows text.

    Returns:
        Total text split length (positions consumed in the packed sequence).
    """
    n = num_caption_tokens
    if "bos_token_id" in special_tokens:
        n += 1
    n += 1  # EOS
    if has_generation:
        n += 1  # start-of-generation / BOV
    return n


def _pack_text_tokens(
    packed_seq: "PackedSequence",
    text_ids: List[int],
    special_tokens: Dict[str, int],
    curr_rope_id: int,
    has_generation: bool,
    use_float_positions: bool = False,
) -> Tuple[int, int, int]:
    """Pack one text split into the sequence being built.

    The split is laid out as ``[BOS?] caption tokens, EOS, [start_of_generation?]``.
    Cross-entropy supervision covers the BOS/caption positions; each of them is
    labelled with the token that follows it, the last one with EOS.

    Args:
        packed_seq: PackedSequence in build mode (list-valued fields) to append to.
        text_ids: List of text token IDs (integers).
        special_tokens: Dictionary of special token IDs. ``"eos_token_id"`` is
            required, ``"bos_token_id"`` is optional, and
            ``"start_of_generation"`` is required when ``has_generation`` is True.
        curr_rope_id: Current 1D RoPE position ID.
        has_generation: Whether media/action tokens follow the text.
        use_float_positions: If True, generate float position IDs for 3D mRoPE
            (for consistency with FPS-modulated vision tokens).

    Returns:
        Tuple of (updated curr_rope_id, split_length, sample_length); the last
        two entries are both the length of this text split.

    Raises:
        KeyError: If a required special token ID is missing.
    """
    # Ensure we're in build mode (fields are lists, not tensors)
    assert isinstance(packed_seq.text_ids, list), "PackedSequence must be in build mode"
    assert isinstance(packed_seq.text_indexes, list)
    assert isinstance(packed_seq.position_ids, list)
    assert isinstance(packed_seq.label_ids, list)
    assert isinstance(packed_seq.ce_loss_indexes, list)
    assert isinstance(packed_seq.ce_loss_weights, list)

    curr = packed_seq.curr
    eos_token_id = special_tokens["eos_token_id"]

    # Prepend BOS token if available
    if "bos_token_id" in special_tokens:
        shifted_text_ids = [special_tokens["bos_token_id"]] + text_ids
    else:
        shifted_text_ids = text_ids
    num_shifted = len(shifted_text_ids)

    split_len = 0

    # Add text tokens to sequence
    packed_seq.text_ids.extend(shifted_text_ids)
    packed_seq.text_indexes.extend(range(curr, curr + num_shifted))

    # Configure loss computation for text tokens
    packed_seq.ce_loss_indexes.extend(range(curr, curr + num_shifted))
    packed_seq.ce_loss_weights.extend([1.0] * num_shifted)
    # One label per supervised position: the next token, then EOS for the last.
    # Deriving labels from shifted_text_ids keeps them aligned with
    # ce_loss_indexes whether or not BOS was prepended; with no supervised
    # position there is nothing to label.
    if shifted_text_ids:
        packed_seq.label_ids.extend(shifted_text_ids[1:] + [eos_token_id])

    curr += num_shifted
    split_len += num_shifted

    # Add EOS token
    packed_seq.text_ids.append(eos_token_id)
    packed_seq.text_indexes.append(curr)
    curr += 1
    split_len += 1

    # Add start-of-generation token, but only if there's media/action present.
    if has_generation:
        packed_seq.text_ids.append(special_tokens["start_of_generation"])
        packed_seq.text_indexes.append(curr)
        curr += 1
        split_len += 1

    # Sanity check -- compute_text_split_length() is called elsewhere.
    assert split_len == compute_text_split_length(len(text_ids), special_tokens, has_generation)

    # Update position IDs and attention mode for text split
    if packed_seq._use_mrope:
        text_mrope_ids, packed_seq._mrope_temporal_offset = get_3d_mrope_ids_text_tokens(
            num_tokens=split_len,
            temporal_offset=packed_seq._mrope_temporal_offset,
            use_float_positions=use_float_positions,
        )  # text_mrope_ids: [3,split_len]
        packed_seq.position_ids.append(text_mrope_ids)
    else:
        packed_seq.position_ids.extend(range(curr_rope_id, curr_rope_id + split_len))
    packed_seq.attn_modes.append("causal")
    packed_seq.split_lens.append(split_len)

    packed_seq.curr = curr
    return curr_rope_id + split_len, split_len, split_len
- temporal_compression_factor: VAE temporal compression factor (default 4). - vision_temporal_positions: Optional explicit temporal coordinate per latent - frame, shape ``(T,)``. Used by UniAE to account for kept boundary latents. - Returns: - Vision split length. - """ - # Ensure we're in build mode - assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" - - curr = packed_seq.curr - vision_split_len = 0 - - # Initialize vision modality if not present. - if packed_seq.vision is None: - packed_seq.vision = ModalityData() - - # Ensure vision modality is in build mode - assert isinstance(packed_seq.vision.sequence_indexes, list) - assert isinstance(packed_seq.vision.mse_loss_indexes, list) - assert isinstance(packed_seq.vision.timesteps, list) - assert isinstance(packed_seq.vision.tokens, list) - - # Compute position IDs for image patches - _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape - if latent_patch_size < 1: - raise ValueError(f"latent_patch_size must be >= 1, got {latent_patch_size}") - # Use ceil to support latent dims not divisible by patch size (padding handled in network) - patch_h = math.ceil(latent_h / latent_patch_size) - patch_w = math.ceil(latent_w / latent_patch_size) - packed_seq.vision.token_shapes.append((latent_t, patch_h, patch_w)) - packed_seq.vision.tokens.append(input_vision_tokens) - - # Add image token indexes and loss information - num_vision_tokens = latent_t * patch_h * patch_w - packed_seq.vision.sequence_indexes.extend(range(curr, curr + num_vision_tokens)) - - # Supervise vision tokens based on conditioning frames - condition_set = {idx for idx in condition_frame_indexes_vision if 0 <= idx < latent_t} - assert isinstance(packed_seq.vision.condition_mask, list) - - vision_condition_mask = torch.zeros( - (latent_t, 1, 1), device=input_vision_tokens.device, dtype=input_vision_tokens.dtype - ) # [T,1,1] - for frame_idx in condition_set: - vision_condition_mask[frame_idx, 0, 0] = 1.0 - 
packed_seq.vision.condition_mask.append(vision_condition_mask) - - vision_noisy_frame_indexes = torch.tensor( - [idx for idx in range(latent_t) if idx not in condition_set], - device=input_vision_tokens.device, - dtype=torch.long, - ) # [N_noisy_frames] - assert isinstance(packed_seq.vision.noisy_frame_indexes, list) - packed_seq.vision.noisy_frame_indexes.append(vision_noisy_frame_indexes) - - frame_token_stride = patch_h * patch_w - for frame_idx in range(latent_t): - if frame_idx in condition_set: - continue - frame_start = curr + frame_idx * frame_token_stride - frame_end = frame_start + frame_token_stride - packed_seq.vision.mse_loss_indexes.extend(range(frame_start, frame_end)) - if isinstance(input_timestep, torch.Tensor): - frame_ts = input_timestep[frame_idx].item() - else: - frame_ts = input_timestep - packed_seq.vision.timesteps.extend([frame_ts] * frame_token_stride) - - curr += num_vision_tokens - vision_split_len += num_vision_tokens - - # Update position IDs for image split - if packed_seq._use_mrope: - # Determine FPS for this vision segment (None disables FPS modulation) - effective_fps = vision_fps if enable_fps_modulation else None - if vision_temporal_positions is not None: - vision_temporal_positions = vision_temporal_positions.to(device="cpu", dtype=torch.float32) # [T] - - vision_mrope_ids, packed_seq._mrope_temporal_offset = get_3d_mrope_ids_vae_tokens( - grid_t=latent_t, - grid_h=patch_h, - grid_w=patch_w, - temporal_offset=packed_seq._mrope_temporal_offset, - reset_spatial_indices=packed_seq._mrope_reset_spatial, - fps=effective_fps, - base_fps=base_fps, - temporal_compression_factor=temporal_compression_factor, - temporal_positions=vision_temporal_positions, - actual_temporal_compression_factor=temporal_compression_factor, - ) # vision_mrope_ids: [3,N_vision_tokens] - packed_seq.position_ids.append(vision_mrope_ids) - else: - # All image tokens share the same RoPE position ID - packed_seq.position_ids.extend([curr_rope_id] * 
vision_split_len) - - packed_seq.curr = curr - return vision_split_len - - -def _pack_action_tokens( - packed_seq: PackedSequence, - input_action_tokens: torch.Tensor, - condition_frame_indexes_action: list[int], - input_timestep: float, - curr_rope_id: int, - action_temporal_offset: int | float = 0, - enable_fps_modulation: bool = False, - base_fps: float = 24.0, - action_fps: float | None = None, - base_temporal_compression_factor: int | None = None, - action_start_frame_offset: int = 1, -) -> int: - """Pack action tokens into the sequence. - - Args: - packed_seq: PackedSequence instance to accumulate data into. - input_action_tokens: Action latent tokens (T, D). - condition_frame_indexes_action: Indexes of conditioning action steps. - input_timestep: Diffusion timestep. - curr_rope_id: Current RoPE position ID. - action_temporal_offset: Temporal offset for action mRoPE IDs (typically - the vision start offset so action aligns temporally with vision). - enable_fps_modulation: If True, scale temporal position IDs based on FPS. - base_fps: Base FPS for normalization (default 24.0). - action_fps: Frames per second of the action data. Used when enable_fps_modulation=True. - base_temporal_compression_factor: Base temporal compression factor for FPS scaling. - Should be set to the vision temporal compression factor (e.g. 4) so that action - tokens advance at frame rate (4x finer) relative to vision latent frames. - Only affects behavior when FPS modulation is enabled. - action_start_frame_offset: Frame offset for aligning action[0] with the - corresponding vision frame. Default 1 aligns action[0] with vision frame 1. - Returns: - Number of action tokens added. 
- """ - # Ensure we're in build mode - assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" - - curr = packed_seq.curr - action_split_len = input_action_tokens.shape[0] - - # Initialize action modality if not present - if packed_seq.action is None: - packed_seq.action = ModalityData() - - # Ensure action modality is in build mode - assert isinstance(packed_seq.action.sequence_indexes, list) - assert isinstance(packed_seq.action.mse_loss_indexes, list) - assert isinstance(packed_seq.action.timesteps, list) - assert isinstance(packed_seq.action.tokens, list) - - # Add token indexes and loss information - action_indexes = list(range(curr, curr + action_split_len)) - packed_seq.action.sequence_indexes.extend(action_indexes) - packed_seq.action.token_shapes.append((action_split_len,)) - packed_seq.action.tokens.append(input_action_tokens) - - condition_set = {idx for idx in condition_frame_indexes_action if 0 <= idx < action_split_len} - assert isinstance(packed_seq.action.condition_mask, list) - - action_condition_mask = torch.zeros( - (action_split_len, 1), device=input_action_tokens.device, dtype=input_action_tokens.dtype - ) # [T_action,1] - for frame_idx in condition_set: - action_condition_mask[frame_idx, 0] = 1.0 - packed_seq.action.condition_mask.append(action_condition_mask) - - action_noisy_frame_indexes = torch.tensor( - [idx for idx in range(action_split_len) if idx not in condition_set], - device=input_action_tokens.device, - dtype=torch.long, - ) # [N_noisy_action_frames] - assert isinstance(packed_seq.action.noisy_frame_indexes, list) - packed_seq.action.noisy_frame_indexes.append(action_noisy_frame_indexes) - - frame_token_stride = 1 # Action has 1 token per frame (no spatial dimension) - for frame_idx in range(action_split_len): - if frame_idx in condition_set: - continue - frame_start = curr + frame_idx * frame_token_stride - frame_end = frame_start + frame_token_stride - 
packed_seq.action.mse_loss_indexes.extend(range(frame_start, frame_end)) - packed_seq.action.timesteps.extend([input_timestep] * frame_token_stride) - - # Update RoPE position IDs for action tokens. - if packed_seq._use_mrope: - # 3D mRoPE: action tokens use a 1x1 spatial grid with start_frame_offset=1 - # so action[0] (null token) aligns with vision frame 1, not frame 0. - effective_fps = action_fps if enable_fps_modulation else None - - action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( - grid_t=action_split_len, - grid_h=1, - grid_w=1, - temporal_offset=action_temporal_offset, - reset_spatial_indices=packed_seq._mrope_reset_spatial, - fps=effective_fps, - base_fps=base_fps, - temporal_compression_factor=1, # Action is at frame rate (no temporal compression) - base_temporal_compression_factor=base_temporal_compression_factor, - start_frame_offset=action_start_frame_offset, # Align action[0] with vision frame action_start_frame_offset - ) # action_mrope_ids: [3,N_action_tokens] - packed_seq.position_ids.append(action_mrope_ids) - # Note: we don't update _mrope_temporal_offset here because action tokens - # share the temporal space with vision tokens (they run in parallel). - else: - # All action tokens share the SAME RoPE position as vision tokens (see docs/sequence_packing.md). - packed_seq.position_ids.extend([curr_rope_id] * action_split_len) - - packed_seq.curr = curr + action_split_len - return action_split_len - - -def _pack_sound_tokens( - packed_seq: PackedSequence, - input_sound_tokens: torch.Tensor, - condition_frame_indexes_sound: list[int], - input_timestep: float, - curr_rope_id: int, - sound_temporal_offset: int | float = 0, - enable_fps_modulation: bool = False, - base_fps: float = 24.0, - sound_fps: float | None = None, - sound_base_temporal_compression_factor: int | None = None, -) -> int: - """Pack sound/audio tokens into the sequence. - - Sound latents have shape [C, T] where C is channels and T is temporal frames. 
- Sound tokens are added to the unified generation split to maintain FactoredSequencePack's - 2-split invariant (causal + full). - - Args: - packed_seq: PackedSequence instance to accumulate data into. - input_sound_tokens: Sound latent tokens (C, T). - condition_frame_indexes_sound: Indexes of conditioning frames. - [] means all frames are noised/supervised. - All frames specified means all frames are clean (no MSE supervision). - input_timestep: Diffusion timestep. - curr_rope_id: Current RoPE position ID. - sound_temporal_offset: Temporal offset for m-RoPE position IDs (aligned with vision start). - enable_fps_modulation: If True, scale temporal positions by FPS ratio. - base_fps: Base FPS for normalization (default 24.0). - sound_fps: Sound latent FPS (e.g., 25.0). Used for FPS-aware m-RoPE positions. - sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. - ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. - - Returns: - Number of sound tokens added. 
- """ - # Ensure we're in build mode - assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" - - curr = packed_seq.curr - - # Sound latent shape: [C, T] → T tokens - _, sound_split_len = input_sound_tokens.shape - - # Initialize sound modality if not present - if packed_seq.sound is None: - packed_seq.sound = ModalityData() - - # Ensure sound modality is in build mode - assert isinstance(packed_seq.sound.sequence_indexes, list) - assert isinstance(packed_seq.sound.mse_loss_indexes, list) - assert isinstance(packed_seq.sound.timesteps, list) - assert isinstance(packed_seq.sound.tokens, list) - - # Add token indexes - sound uses (T, 1, 1) shape for compatibility with 3D RoPE - packed_seq.sound.token_shapes.append((sound_split_len, 1, 1)) - packed_seq.sound.sequence_indexes.extend(range(curr, curr + sound_split_len)) - packed_seq.sound.tokens.append(input_sound_tokens) - - # Supervise sound tokens based on conditioning frames - condition_set = {idx for idx in condition_frame_indexes_sound if 0 <= idx < sound_split_len} - assert isinstance(packed_seq.sound.condition_mask, list) - - # Condition mask: shape (T, 1) — 1 = clean/conditioning, 0 = noised/supervised - sound_condition_mask = torch.zeros( - (sound_split_len, 1), device=input_sound_tokens.device, dtype=input_sound_tokens.dtype - ) # [T_sound,1] - for frame_idx in condition_set: - sound_condition_mask[frame_idx, 0] = 1.0 - packed_seq.sound.condition_mask.append(sound_condition_mask) - - sound_noisy_frame_indexes = torch.tensor( - [idx for idx in range(sound_split_len) if idx not in condition_set], - device=input_sound_tokens.device, - dtype=torch.long, - ) # [N_noisy_sound_frames] - assert isinstance(packed_seq.sound.noisy_frame_indexes, list) - packed_seq.sound.noisy_frame_indexes.append(sound_noisy_frame_indexes) - - # Add to MSE loss indexes and timesteps for non-conditioning frames - for frame_idx in range(sound_split_len): - if frame_idx in condition_set: - continue - # 
Sound has 1 token per frame (no spatial dimension) - frame_start = curr + frame_idx - frame_end = frame_start + 1 - packed_seq.sound.mse_loss_indexes.extend(range(frame_start, frame_end)) - packed_seq.sound.timesteps.extend([input_timestep]) - - # Update RoPE position IDs for sound tokens. - if packed_seq._use_mrope: - # 3D mRoPE: sound tokens use a 1x1 spatial grid, aligned with vision temporal positions. - # sound[0] aligns with vision frame 0 (start_frame_offset=0, unlike action which offsets by 1). - effective_fps = sound_fps if enable_fps_modulation else None - - sound_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( - grid_t=sound_split_len, - grid_h=1, - grid_w=1, - temporal_offset=sound_temporal_offset, - reset_spatial_indices=packed_seq._mrope_reset_spatial, - fps=effective_fps, - base_fps=base_fps, - temporal_compression_factor=1, # Sound latent is already at sound_latent_fps (no further compression) - base_temporal_compression_factor=sound_base_temporal_compression_factor, - start_frame_offset=0, # Sound[0] aligns with vision frame 0 - ) # sound_mrope_ids: [3,N_sound_tokens] - packed_seq.position_ids.append(sound_mrope_ids) - # Note: we don't update _mrope_temporal_offset here because sound tokens - # share the temporal space with vision tokens (they run in parallel). - else: - # All sound tokens share the SAME RoPE position as vision/action tokens (unified generation split). 
- packed_seq.position_ids.extend([curr_rope_id] * sound_split_len) - - packed_seq.curr = curr + sound_split_len - return sound_split_len - - -def _pack_supertokens_temporal_causal( - packed_seq: "PackedSequence", - input_vision_tokens: torch.Tensor, - input_action_tokens: torch.Tensor | None, - condition_frame_indexes_vision: list[int], - input_timestep: float | torch.Tensor, - curr_rope_id: int, - latent_patch_size: int, - temporal_compression_factor: int, - action_dim: int, - vision_fps: float | None = None, - action_fps: float | None = None, - enable_fps_modulation: bool = False, - base_fps: float = 24.0, - pack_action_tokens: bool = True, -) -> tuple[int, bool]: - """Pack vision and (optionally) action tokens in supertoken order for temporal causal attention. - - Buffer layout per frame: - pack_action_tokens=True: [action_t (tcf), vision_t (H*W)] — supertoken size tcf + H*W - pack_action_tokens=False: [vision_t (H*W)] — supertoken size H*W - - Use ``pack_action_tokens=False`` when ``config.action_gen=False``; the resulting - ``num_action_tokens_per_supertoken=0`` is stamped on the pack and read by the - attention builder so NATTEN metadata stays in sync automatically. - - mRoPE layout (with actions, unified_3d_mrope only). The layout is inferred from the - action tensor shape: - - Whole-clip training (frame 0 is the clean conditioning frame, so - ``real_actions`` has ``(T-1)*tcf`` rows): null action for supertoken 0, real - actions for frames 1..T-1 with ``start_frame_offset=1`` so the last action in - group i co-locates with vision frame i; vision uses ``start_frame_offset=0``. - - AR generation, single frame OR chunk (every frame carries a real action, so - ``real_actions`` has ``latent_t*tcf`` rows): vision AND action both use - ``start_frame_offset=1``, generalizing the single-frame AR supertoken to - ``latent_t`` frames. 
The caller (``pack_input_sequence_autoregressive``) - seeds ``temporal_offset`` one frame-stride back to compensate, so the unit - lands at the same absolute positions as the whole-clip training pack. - - Interleaved per frame as cat([action_ids, vision_ids]). - - ``input_timestep`` is float (TF/none) or Tensor(T_max,) (DF, per-frame sigma). - Conditioning frames are excluded from mse_loss_indexes either way. - - Returns (total_split_len, null_action_flag); null_action_flag is False when - pack_action_tokens=False. - """ - assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" - - _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape - patch_h = math.ceil(latent_h / latent_patch_size) - patch_w = math.ceil(latent_w / latent_patch_size) - tcf = temporal_compression_factor - patches_per_frame = patch_h * patch_w - supertoken_len = tcf + patches_per_frame if pack_action_tokens else patches_per_frame # S - - # Initialize modalities if needed - if packed_seq.vision is None: - packed_seq.vision = ModalityData() - if pack_action_tokens and packed_seq.action is None: - packed_seq.action = ModalityData() - - assert isinstance(packed_seq.vision.sequence_indexes, list) - assert isinstance(packed_seq.vision.mse_loss_indexes, list) - assert isinstance(packed_seq.vision.timesteps, list) - assert isinstance(packed_seq.vision.tokens, list) - assert isinstance(packed_seq.vision.condition_mask, list) - if pack_action_tokens: - assert isinstance(packed_seq.action.sequence_indexes, list) - assert isinstance(packed_seq.action.mse_loss_indexes, list) - assert isinstance(packed_seq.action.timesteps, list) - assert isinstance(packed_seq.action.tokens, list) - assert isinstance(packed_seq.action.condition_mask, list) - - device = input_vision_tokens.device - dtype = input_vision_tokens.dtype - - null_action_flag: bool - if pack_action_tokens: - # Build all_action_tokens: shape (latent_t * tcf, action_dim) - # - # Cases (token assembly; mRoPE 
start_frame_offset is chosen separately below, - # inferred from the same action shape): - # 1. Whole-clip training with conditioning frame (latent_t > 1, real_actions - # has (T-1)*tcf rows): prepend tcf null tokens for frame 0, then real - # actions for frames 1..T-1. - # 2. AR generation (every frame has a real action, real_actions has - # latent_t*tcf rows — single frame OR chunk): no null prefix. - # 3. AR frame 0 / image2video (action is None): all null tokens. - if input_action_tokens is not None: - # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (T*tcf, D) for AR units - if input_action_tokens.dim() == 3: - real_actions = input_action_tokens.squeeze(0) # [T*tcf,action_dim] or [N,action_dim] - else: - real_actions = input_action_tokens # [N,action_dim] - null_tokens = torch.zeros(tcf, action_dim, device=device, dtype=real_actions.dtype) # [tcf,action_dim] - if real_actions.shape[0] == latent_t * tcf: - # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): - # every supertoken carries a real action, no null prefix. - all_action_tokens = real_actions - null_action_flag = False - elif real_actions.shape[0] == (latent_t - 1) * tcf: - # Conditioning frame present: null for supertoken 0, real for 1..T-1 - all_action_tokens = torch.cat([null_tokens, real_actions], dim=0) # [T*tcf,action_dim] - null_action_flag = True - else: - raise ValueError( - "Temporal-causal action tokens must have either latent_t*tcf rows for AR chunks " - f"or (latent_t-1)*tcf rows for whole-clip training; got {real_actions.shape[0]} rows " - f"for latent_t={latent_t}, tcf={tcf}." - ) - else: - # AR frame 0 or image2video: all action tokens are null - all_action_tokens = torch.zeros( - latent_t * tcf, action_dim, device=device, dtype=dtype - ) # [T*tcf,action_dim] - null_action_flag = True - else: - # pack_action_tokens=False: action tokens must not be supplied. 
- assert input_action_tokens is None, ( - "pack_action_tokens=False requires input_action_tokens=None; got a non-None tensor." - ) - null_action_flag = False - - # Record vision token shapes and tokens - packed_seq.vision.token_shapes.append((latent_t, patch_h, patch_w)) - packed_seq.vision.tokens.append(input_vision_tokens) - - # Vision conditioning mask: (T, 1, 1) - condition_set_vision = {idx for idx in condition_frame_indexes_vision if 0 <= idx < latent_t} - vision_condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype) # [T,1,1] - for fidx in condition_set_vision: - vision_condition_mask[fidx, 0, 0] = 1.0 - packed_seq.vision.condition_mask.append(vision_condition_mask) - - vision_noisy_frame_indexes = torch.tensor( - [idx for idx in range(latent_t) if idx not in condition_set_vision], - device=device, - dtype=torch.long, - ) # [N_noisy_frames] - packed_seq.vision.noisy_frame_indexes.append(vision_noisy_frame_indexes) - - if pack_action_tokens: - # Action token shapes: latent_t * tcf total (including null tokens) - packed_seq.action.token_shapes.append((latent_t * tcf,)) - packed_seq.action.tokens.append(all_action_tokens) - - # Action conditioning mask: all action tokens are conditioning (not supervised) - # Null tokens are always conditioning; real actions are conditioning too (they are inputs) - action_condition_mask = torch.ones((latent_t * tcf, 1), device=device, dtype=dtype) # [T*tcf,1] - packed_seq.action.condition_mask.append(action_condition_mask) - - # Pack in interleaved supertoken order: [action_t, vision_t] for each frame t - # (or just [vision_t] per frame when pack_action_tokens=False) - curr = packed_seq.curr - total_split_len = 0 - - # mRoPE: snapshot offset before this sample, compute IDs - if packed_seq._use_mrope: - temporal_offset = packed_seq._mrope_temporal_offset - effective_vision_fps = vision_fps if enable_fps_modulation else None - - # AR generation (single frame OR chunk) is detected by every frame carrying a - # 
real action (``real_actions`` has ``latent_t*tcf`` rows). There, vision AND - # action both use start_frame_offset=1 so the last action in each group - # co-locates with its vision frame, mirroring whole-clip training; the caller - # (pack_input_sequence_autoregressive) seeds temporal_offset one frame-stride - # back to compensate. Whole-clip training (frame 0 is the null conditioning - # frame, ``real_actions`` has ``(T-1)*tcf`` rows) keeps vision start_frame_offset=0. - all_frames_have_real_action = ( - pack_action_tokens and input_action_tokens is not None and real_actions.shape[0] == latent_t * tcf - ) - vision_sfo = 1 if all_frames_have_real_action else 0 - - vision_ids_flat, new_offset = get_3d_mrope_ids_vae_tokens( - grid_t=latent_t, - grid_h=patch_h, - grid_w=patch_w, - temporal_offset=temporal_offset, - reset_spatial_indices=packed_seq._mrope_reset_spatial, - fps=effective_vision_fps, - base_fps=base_fps, - temporal_compression_factor=tcf, - start_frame_offset=vision_sfo, - ) # vision_ids_flat: [3,T*patch_h*patch_w] - - if pack_action_tokens: - effective_action_fps = action_fps if enable_fps_modulation else None - - # Action IDs. Real action tokens use start_frame_offset=1 so the last - # sub-token of a group co-locates with its vision frame. Whole-clip training - # has a null action at frame 0 (the conditioning frame); AR units have a real - # action for every frame. 
- fps_active = effective_action_fps is not None - t_dtype = torch.float32 if fps_active else torch.long - t_offset = float(temporal_offset) if fps_active else int(temporal_offset) - null_t = torch.full((tcf,), t_offset, dtype=t_dtype) # [tcf] - null_hw = torch.zeros(tcf, dtype=t_dtype) # [tcf] - null_ids = torch.stack([null_t, null_hw, null_hw]) # [3,tcf] - - def _real_action_ids(n_frames: int, start_frame_offset: int) -> torch.Tensor: - flat, _ = get_3d_mrope_ids_vae_tokens( - grid_t=n_frames * tcf, - grid_h=1, - grid_w=1, - temporal_offset=temporal_offset, - reset_spatial_indices=packed_seq._mrope_reset_spatial, - fps=effective_action_fps, - base_fps=base_fps, - temporal_compression_factor=1, - base_temporal_compression_factor=tcf, - start_frame_offset=start_frame_offset, - ) - return flat.reshape(3, n_frames, tcf) # [3,n_frames,tcf] - - if all_frames_have_real_action: - # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf): - # every supertoken carries a real action. start_frame_offset=1 puts - # a_{j-1}'s last sub-token on vision frame j -- the whole-clip TF - # training layout. The caller seeds temporal_offset (N-1) frame-strides - # back to compensate. - action_ids_3d = _real_action_ids(latent_t, start_frame_offset=1) # [3,T,tcf] - elif latent_t > 1: - # Whole-clip training: supertoken 0 = null (conditioning frame), frames - # 1..T-1 = real with start_frame_offset=1. Covers real-action training - # (real_actions has (T-1)*tcf rows) and the architectural all-null layout - # (input_action_tokens is None); the tokens differ but the IDs match. - null_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] - real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1) # [3,T-1,tcf] - action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1) # [3,T,tcf] - else: - # AR frame 0 / image2video (latent_t == 1, no action): only null. 
- action_ids_3d = null_ids.reshape(3, 1, tcf) # [3,1,tcf] - - # (3, T*H*W) → (3, T, H*W) - vision_ids_3d = vision_ids_flat.reshape(3, latent_t, patches_per_frame) # [3,T,patch_h*patch_w] - - # Interleave per frame: (3, T, tcf+H*W) → (3, T*S) - interleaved_ids = torch.cat([action_ids_3d, vision_ids_3d], dim=2).reshape( - 3, latent_t * supertoken_len - ) # [3,T*S] - packed_seq.position_ids.append(interleaved_ids) - else: - # No action tokens: just vision IDs, already in (3, T*H*W) order. - packed_seq.position_ids.append(vision_ids_flat) - - packed_seq._mrope_temporal_offset = new_offset - - for frame_t in range(latent_t): - if pack_action_tokens: - # Pack action tokens for this frame (indexes only; tokens already stored in packed_seq.action.tokens) - action_indexes = list(range(curr, curr + tcf)) - packed_seq.action.sequence_indexes.extend(action_indexes) - # Action tokens are never in MSE loss (always conditioning) - curr += tcf - total_split_len += tcf - - if not packed_seq._use_mrope: - packed_seq.position_ids.extend([curr_rope_id] * tcf) - - # Pack vision tokens for this frame - frame_indexes = list(range(curr, curr + patches_per_frame)) - packed_seq.vision.sequence_indexes.extend(frame_indexes) - curr += patches_per_frame - total_split_len += patches_per_frame - - if not packed_seq._use_mrope: - packed_seq.position_ids.extend([curr_rope_id] * patches_per_frame) - - # Vision MSE loss: supervise non-conditioning frames - if frame_t not in condition_set_vision: - packed_seq.vision.mse_loss_indexes.extend(frame_indexes) - frame_ts = input_timestep[frame_t].item() if isinstance(input_timestep, torch.Tensor) else input_timestep - packed_seq.vision.timesteps.extend([frame_ts] * patches_per_frame) - - packed_seq.curr = curr - return total_split_len, null_action_flag - - -# ============================================================================ -# Main packing function -# ============================================================================ - - -def 
pack_input_sequence( - sequence_plans: list[SequencePlan], - input_text_indexes: list[list[int]], - gen_data_clean: GenerationDataClean, - input_timesteps: torch.Tensor, - special_tokens: dict[str, int], - max_num_tokens: int | None = None, - latent_patch_size: int = 1, - skip_text_tokens: bool = False, - include_end_of_generation_token: bool = False, - position_embedding_type: str = "3d_rope", - unified_3d_mrope_reset_spatial_ids: bool = True, - unified_3d_mrope_temporal_modality_margin: int = 0, - enable_fps_modulation: bool = False, - base_fps: float = 24.0, - sound_base_temporal_compression_factor: int | None = None, - temporal_compression_factor: int = 4, - vision_temporal_position_mode: str = "latent_index", - video_temporal_causal: bool = False, - action_dim: int = 32, - initial_mrope_temporal_offset: int | float = 0, -) -> PackedSequence: - """ - Pack a sequence of input strings and VAE latents into a packed tensor format. - Uses SequencePlan to determine which modalities are present for each sample, - and maintains separate indices for text, vision, action, and sound to handle variable modality presence. - - Args: - sequence_plans: List of SequencePlan items describing which modalities are present. - input_text_indexes: List of text token ID sequences (only for samples where has_text=True). - gen_data_clean: GenerationDataClean containing vision, action, and sound tensors. - - x0_tokens_vision: Vision tensors for samples where has_vision=True - - x0_tokens_action: Action tensors for samples where has_action=True - - x0_tokens_sound: Sound tensors (list of [C, T]) for samples where has_sound=True - input_timesteps: Diffusion timesteps for each sample. Shape (B,) or (B, 1) for - teacher_forcing/none (all frames share the same sigma), or (B, T_max) for - diffusion_forcing (per-frame independent sigma). Entries are extracted per - sample as a float (numel==1) or Tensor(T_max,) for per-frame indexing. 
- special_tokens: Dictionary containing special token IDs (eos_token_id, start_of_generation, end_of_generation) - max_num_tokens: Maximum number of tokens in the packed sequence - latent_patch_size: Patch size used by the network to pack latents - skip_text_tokens: If True, skip packing text tokens - include_end_of_generation_token: If True, append end-of-generation token - position_embedding_type: Position embedding type for vision tokens: - - "3d_rope": Additive 3D RoPE embeddings + 1D position IDs for attention - - "flattened_sin_cos": Additive flattened sin/cos embeddings + 1D position IDs - - "unified_3d_mrope": No additive embedding + 3D position IDs for Qwen3VL-style mRoPE - unified_3d_mrope_reset_spatial_ids: If True (default), spatial (H, W) indices - start from 0 for each vision segment. If False, spatial indices are offset - by the temporal offset (Qwen2VL-style). Only used when position_embedding_type="unified_3d_mrope". - enable_fps_modulation: If True, scale temporal position IDs based on video FPS - to reflect real time. Requires fps_vision in gen_data_clean. - Uses the same flag as diffusion_expert_config.enable_fps_modulation. - base_fps: Base FPS for normalization (default 24.0). - Uses the same value as diffusion_expert_config.base_fps. - sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. - ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. - temporal_compression_factor: VAE temporal compression factor (default 4). - Obtained from the VAE tokenizer at runtime. - vision_temporal_position_mode: Temporal coordinates used for unified_3d_mrope vision tokens. - "latent_index" keeps legacy positions; "uniae_source_right_edge" uses - per-latent positions from gen_data_clean.temporal_positions_vision. - Returns: - PackedSequence containing all packed tensors and metadata. See PackedSequence for field details. 
- """ - del max_num_tokens - - assert special_tokens is not None, "Special tokens must be provided" - assert isinstance(input_timesteps, torch.Tensor), "input_timesteps must be a tensor" - if input_timesteps.is_cuda: - raise ValueError("input_timesteps must be on CPU, not CUDA") - if isinstance(input_text_indexes, torch.Tensor): - raise ValueError("input_text_tokens must be a list, not a tensor") - - supported_vision_temporal_position_modes = {"latent_index", "uniae_source_right_edge"} - if vision_temporal_position_mode not in supported_vision_temporal_position_modes: - raise ValueError( - "Unsupported vision_temporal_position_mode: " - f"{vision_temporal_position_mode}. Supported modes: {supported_vision_temporal_position_modes}." - ) - has_any_vision = any(plan.has_vision for plan in sequence_plans) - explicit_vision_temporal_positions_active = vision_temporal_position_mode != "latent_index" and has_any_vision - if explicit_vision_temporal_positions_active: - if position_embedding_type != "unified_3d_mrope": - raise NotImplementedError( - "Explicit vision temporal positions are only supported with position_embedding_type='unified_3d_mrope'." - ) - if gen_data_clean.temporal_positions_vision is None: - raise ValueError( - f"vision_temporal_position_mode={vision_temporal_position_mode} requires " - "gen_data_clean.temporal_positions_vision." - ) - if gen_data_clean.x0_tokens_vision is not None and len(gen_data_clean.temporal_positions_vision) != len( - gen_data_clean.x0_tokens_vision - ): - raise ValueError( - "temporal_positions_vision must have one entry per x0_tokens_vision item, " - f"got {len(gen_data_clean.temporal_positions_vision)} positions for " - f"{len(gen_data_clean.x0_tokens_vision)} vision items." - ) - if video_temporal_causal: - raise NotImplementedError( - "video_temporal_causal=True is not wired for explicit UniAE vision temporal positions yet." 
- ) - if any(plan.has_action for plan in sequence_plans): - raise NotImplementedError("Action packing is not wired for explicit UniAE vision temporal positions yet.") - if initial_mrope_temporal_offset != 0: - raise NotImplementedError( - "Autoregressive mRoPE temporal offsets are not wired for explicit UniAE vision temporal positions yet." - ) - use_float_mrope_positions = enable_fps_modulation or explicit_vision_temporal_positions_active - - # Initialize packed sequence (acts as builder during packing) - packed_seq = PackedSequence() - - # Configure 3D mRoPE on the builder (enabled when position_embedding_type is unified_3d_mrope) - packed_seq._use_mrope = position_embedding_type == "unified_3d_mrope" - packed_seq._mrope_reset_spatial = unified_3d_mrope_reset_spatial_ids - - # Maintain separate indices for each modality - idx_text = 0 - idx_vision = 0 - idx_action = 0 - idx_sound = 0 - null_action_flags: list[bool] = [] # collected from TC path; asserted consistent after the loop - - # Validate: all samples must have text (causal split is always required for two-way attention). - # CFG dropout only drops text *content*, not the structural text split. - if not skip_text_tokens: - for plan in sequence_plans: - assert plan.has_text, "All sequence plans must have has_text=True when skip_text_tokens=False" - - # Pack each sample based on its sequence plan - for sample_idx, sequence_plan in enumerate(sequence_plans): - curr_rope_id = 0 - sample_len = 0 - - # mRoPE temporal offset resets per sample. - # initial_mrope_temporal_offset is non-zero only for AR inference (frame N seeds at N*tcf). 
- packed_seq._mrope_temporal_offset = initial_mrope_temporal_offset - - _ts = input_timesteps[sample_idx] - input_timestep = _ts.item() if _ts.numel() == 1 else _ts # float (TF) or Tensor(T_max,) (DF) - - # Pack text tokens if has_text=True and not skipped - if sequence_plan.has_text and not skip_text_tokens: - text_ids = input_text_indexes[idx_text] - idx_text += 1 - - has_generation_for_sample = sequence_plan.has_vision or sequence_plan.has_action or sequence_plan.has_sound - curr_rope_id, _, text_sample_len = _pack_text_tokens( - packed_seq, - text_ids, - special_tokens, - curr_rope_id, - has_generation=has_generation_for_sample, - use_float_positions=use_float_mrope_positions, - ) - sample_len += text_sample_len - - # End of text modality, add an offset as the boundary between text and vision. - packed_seq._mrope_temporal_offset += unified_3d_mrope_temporal_modality_margin - - # Save temporal offset before vision for action tokens (action uses same offset as vision start) - vision_start_temporal_offset = packed_seq._mrope_temporal_offset - - # Pack vision (and optionally action) tokens - if video_temporal_causal and sequence_plan.has_vision: - # Temporal causal path: when sequence_plan.has_action=True, interleaved supertokens - # [action_t, vision_t]; when False, supertokens are just vision patches. 
- assert position_embedding_type == "unified_3d_mrope", ( - "video_temporal_causal=True requires position_embedding_type='unified_3d_mrope'" - ) - input_vision_tokens = gen_data_clean.x0_tokens_vision[idx_vision] - idx_vision += 1 - - vision_fps = None - if ( - enable_fps_modulation - and gen_data_clean.fps_vision is not None - and idx_vision - 1 < len(gen_data_clean.fps_vision) - ): - vision_fps = float(gen_data_clean.fps_vision[idx_vision - 1].item()) - - input_action_tokens_tc: torch.Tensor | None = None - action_fps_tc: float | None = None - if sequence_plan.has_action: - input_action_tokens_tc = gen_data_clean.x0_tokens_action[idx_action] - if ( - enable_fps_modulation - and gen_data_clean.fps_action is not None - and idx_action < len(gen_data_clean.fps_action) - ): - action_fps_tc = float(gen_data_clean.fps_action[idx_action].item()) - idx_action += 1 - - supertoken_split_len, null_flag = _pack_supertokens_temporal_causal( - packed_seq=packed_seq, - input_vision_tokens=input_vision_tokens, - input_action_tokens=input_action_tokens_tc, - condition_frame_indexes_vision=sequence_plan.condition_frame_indexes_vision, - input_timestep=input_timestep, - curr_rope_id=curr_rope_id, - latent_patch_size=latent_patch_size, - temporal_compression_factor=temporal_compression_factor, - action_dim=action_dim, - vision_fps=vision_fps, - action_fps=action_fps_tc, - enable_fps_modulation=enable_fps_modulation, - base_fps=base_fps, - pack_action_tokens=sequence_plan.has_action, - ) - null_action_flags.append(null_flag) - # We assume all samples in a batch share the same has_action layout, so - # stamp the supertoken layout constant directly here. This is the - # single source of truth read by downstream attention / KV-cache - # code (no recomputation in the network). 
- packed_seq.num_action_tokens_per_supertoken = temporal_compression_factor if sequence_plan.has_action else 0 - sample_len += supertoken_split_len - vision_split_len = supertoken_split_len - action_split_len = 0 # Already absorbed into supertoken_split_len - - else: - # Standard path: vision and action packed separately - if sequence_plan.has_vision: - # Determine how many vision items this sample owns. - # For multi-item samples (e.g. image editing), num_vision_items_per_sample - # records [2, 2, ...]; for standard T2I/T2V it is None (1 item per sample). - num_vis = ( - gen_data_clean.num_vision_items_per_sample[sample_idx] - if gen_data_clean.num_vision_items_per_sample is not None - else 1 - ) - - vision_split_len = 0 - # Controlnet-style transfer: when set, all vision items share the same - # temporal mRoPE grid. We snapshot the offset before the loop and - # rewind to it before each item, so every item produces identical - # temporal IDs. Each _pack_vision_tokens call still advances the - # offset by latent_t internally; in shared-grid mode the post-loop - # offset equals snapshot + latent_t (single-clip semantics for - # downstream EOV / next-modality tokens). - shared_grid = sequence_plan.share_vision_temporal_positions and num_vis > 1 - items_temporal_offset_snapshot = packed_seq._mrope_temporal_offset - shared_latent_t: int | None = None - shared_patch_h: int | None = None - shared_patch_w: int | None = None - shared_temporal_positions: torch.Tensor | None = None - # FPS is recorded per-sample (shape [B]); for multi-item samples - # (transfer / image-edit) every vision item in this sample shares - # the same conditioning FPS, so we read by sample_idx, not by the - # flat idx_vision counter (which would alias to a neighbor sample's - # fps and corrupt RoPE FPS modulation). 
- sample_vision_fps: float | None = None - if ( - enable_fps_modulation - and gen_data_clean.fps_vision is not None - and sample_idx < len(gen_data_clean.fps_vision) - ): - sample_vision_fps = float(gen_data_clean.fps_vision[sample_idx].item()) - - for item_idx in range(num_vis): - flat_vision_idx = idx_vision - input_vision_tokens = gen_data_clean.x0_tokens_vision[flat_vision_idx] - vision_temporal_positions: torch.Tensor | None = None - if explicit_vision_temporal_positions_active: - assert gen_data_clean.temporal_positions_vision is not None - vision_temporal_positions = gen_data_clean.temporal_positions_vision[flat_vision_idx] - if vision_temporal_positions.shape[0] != input_vision_tokens.shape[2]: - raise ValueError( - "vision_temporal_positions must match latent_t for each vision item, " - f"got {vision_temporal_positions.shape[0]} positions and " - f"latent_t={input_vision_tokens.shape[2]} for item {flat_vision_idx}." - ) - vision_fps = sample_vision_fps - idx_vision += 1 - - # Determine conditioning for this vision item. - # For multi-item mode: all items except the last are fully conditioned - # (all frames are clean); the last item uses the SequencePlan's - # condition_frame_indexes_vision (typically [] = fully generated). - if num_vis > 1 and item_idx < num_vis - 1: - # Conditioning item (e.g. 
source image): mark all frames as clean - latent_t = input_vision_tokens.shape[2] - item_condition_frames = list(range(latent_t)) - else: - # Generation item (single-item mode or last item in multi-item) - item_condition_frames = sequence_plan.condition_frame_indexes_vision - - if shared_grid: - item_latent_t = input_vision_tokens.shape[2] - item_latent_h = input_vision_tokens.shape[3] - item_latent_w = input_vision_tokens.shape[4] - if shared_latent_t is None: - shared_latent_t = item_latent_t - shared_patch_h = item_latent_h - shared_patch_w = item_latent_w - else: - assert item_latent_t == shared_latent_t, ( - f"share_vision_temporal_positions requires equal latent_t across items, " - f"got item {item_idx} latent_t={item_latent_t} vs first={shared_latent_t}" - ) - assert item_latent_h == shared_patch_h and item_latent_w == shared_patch_w, ( - f"share_vision_temporal_positions requires equal spatial grid across items, " - f"got item {item_idx} (H,W)=({item_latent_h},{item_latent_w}) " - f"vs first=({shared_patch_h},{shared_patch_w})" - ) - if vision_temporal_positions is not None: - if shared_temporal_positions is None: - shared_temporal_positions = vision_temporal_positions - else: - comparison_temporal_positions = vision_temporal_positions.to( - device=shared_temporal_positions.device - ) # [T] - assert torch.allclose(comparison_temporal_positions, shared_temporal_positions), ( - "share_vision_temporal_positions requires equal explicit temporal positions " - f"across vision items, got item {item_idx} positions " - f"{vision_temporal_positions.tolist()} vs first " - f"{shared_temporal_positions.tolist()}." - ) - # Rewind so this item starts at the same temporal offset as item 0. 
- packed_seq._mrope_temporal_offset = items_temporal_offset_snapshot - - item_split_len = _pack_vision_tokens( - packed_seq=packed_seq, - input_vision_tokens=input_vision_tokens, - condition_frame_indexes_vision=item_condition_frames, - input_timestep=input_timestep, - curr_rope_id=curr_rope_id, - latent_patch_size=latent_patch_size, - vision_fps=vision_fps, - enable_fps_modulation=enable_fps_modulation, - base_fps=base_fps, - temporal_compression_factor=temporal_compression_factor, - vision_temporal_positions=vision_temporal_positions, - ) - vision_split_len += item_split_len - sample_len += vision_split_len - - else: - vision_split_len = 0 - - # Pack action tokens if has_action=True - if sequence_plan.has_action: - input_action_tokens = gen_data_clean.x0_tokens_action[idx_action] - - # Get FPS for action (action may have its own FPS independent of vision) - action_fps: float | None = None - if ( - enable_fps_modulation - and gen_data_clean.fps_action is not None - and idx_action < len(gen_data_clean.fps_action) - ): - action_fps = float(gen_data_clean.fps_action[idx_action].item()) - - idx_action += 1 - - action_split_len = _pack_action_tokens( - packed_seq=packed_seq, - input_action_tokens=input_action_tokens, - condition_frame_indexes_action=sequence_plan.condition_frame_indexes_action, - input_timestep=input_timestep, - curr_rope_id=curr_rope_id, - action_temporal_offset=vision_start_temporal_offset, - enable_fps_modulation=enable_fps_modulation, - base_fps=base_fps, - action_fps=action_fps, - base_temporal_compression_factor=temporal_compression_factor, - action_start_frame_offset=sequence_plan.action_start_frame_offset, - ) - sample_len += action_split_len - else: - action_split_len = 0 - - # Pack sound tokens if has_sound=True - if sequence_plan.has_sound: - input_sound_tokens = gen_data_clean.x0_tokens_sound[idx_sound] - - # Get FPS for sound (from gen_data_clean, like vision and action) - sound_fps: float | None = None - if ( - enable_fps_modulation - and 
gen_data_clean.fps_sound is not None - and idx_sound < len(gen_data_clean.fps_sound) - ): - sound_fps = float(gen_data_clean.fps_sound[idx_sound].item()) - - idx_sound += 1 - - sound_split_len = _pack_sound_tokens( - packed_seq=packed_seq, - input_sound_tokens=input_sound_tokens, - condition_frame_indexes_sound=sequence_plan.condition_frame_indexes_sound, - input_timestep=input_timestep, - curr_rope_id=curr_rope_id, - sound_temporal_offset=vision_start_temporal_offset, - enable_fps_modulation=enable_fps_modulation, - base_fps=base_fps, - sound_fps=sound_fps, - sound_base_temporal_compression_factor=sound_base_temporal_compression_factor, - ) - sample_len += sound_split_len - else: - sound_split_len = 0 - - # Add end-of-generation token if needed - eov_len = 0 - has_any_generation = sequence_plan.has_vision or sequence_plan.has_action or sequence_plan.has_sound - if include_end_of_generation_token and has_any_generation: - # Type narrowing: we're in build mode, fields are lists - assert isinstance(packed_seq.text_ids, list) - assert isinstance(packed_seq.text_indexes, list) - assert isinstance(packed_seq.position_ids, list) - - packed_seq.text_ids.append(special_tokens["end_of_generation"]) - packed_seq.text_indexes.append(packed_seq.curr) - - # EOV position IDs: 3D mRoPE or 1D RoPE - if packed_seq._use_mrope: - # Use float dtype when any vision mRoPE positions are fractional. 
- eov_dtype = torch.float32 if use_float_mrope_positions else torch.long - eov_mrope_ids = torch.full((3, 1), packed_seq._mrope_temporal_offset, dtype=eov_dtype) # [3,1] - packed_seq.position_ids.append(eov_mrope_ids) # type: ignore[arg-type] - packed_seq._mrope_temporal_offset += 1 - else: - packed_seq.position_ids.append(curr_rope_id) # type: ignore[arg-type] - - packed_seq.curr += 1 - eov_len = 1 - sample_len += 1 - - combined_split_len = vision_split_len + action_split_len + sound_split_len + eov_len - packed_seq.attn_modes.append("full") - packed_seq.split_lens.append(combined_split_len) - packed_seq.sample_lens.append(sample_len) - - # Assert consistent null_action_supertokens across all TC samples, then set once - if null_action_flags: - assert len(set(null_action_flags)) == 1, ( - f"Inconsistent null_action_supertokens across samples: {null_action_flags}. " - "All samples in a batch must have the same structure (all training or all AR inference)." - ) - packed_seq.null_action_supertokens = null_action_flags[0] - - # Finalize and return packed data - return packed_seq.finalize( - gen_data_clean=gen_data_clean, - ) - - -# ============================================================================ -# SequencePack:Operations on packed sequences -# ============================================================================ - -""" -SequencePack is a dictionary-based container for packed sequences. -We provide two implementations: - -JointSequencePack: Stores all sub-sequences for all-sequences in a single tensor. - It is more flexible but is less performant. In this implementation, understanding tokens - can be placed in either causal or full-attention sub-sequences. -FactoredSequencePack: - Stores causal/undersanding and full/generation sub-sequences as separate tensors. - It is less flexible but is more performant. 
In this implementation, understanding tokens - must be on the causal sub-sequence, and generation tokens must be in the full-attention sub-sequence. - -NOTES: - - We are aiming to deprecate and remove JointSequencePack; keeping it available for backwards compatibility at the moment. - - The reason we're implementing them via dict instead of python classes is to make torch.compile + activation checkpointing to work. - -is_sharded (bool): - This flag indicates whether the sequence pack contains global data or a local shard for Context Parallelism (CP). - - When True, tensors represent only the local slice (Global_Length / CP_World_Size). - - Padding and reconstruction logic is skipped in `from_joint`. - - Operations requiring global context (e.g., `get_all_seq`, position ID reconstruction) are not allowed when is_sharded is True. -""" - - -# "Fake" types for readability; everything is plain dict at runtime. -FactoredSequencePack = dict[str, Any] -JointSequencePack = dict[str, Any] -SequencePack = FactoredSequencePack | JointSequencePack - -# ------------------------------------ -# SequencePack: internal helpers -# ------------------------------------ - - -def _find_non_causal_text_token_idx( - attn_modes: List[str], split_lens: List[int], und_token_indexes: List[int] -) -> List[int]: - """ - Find the indexes of the "und" tokens that are under the "full" mode. - This are indices into the full_only_seq. - """ - # Return indexes *into* full_only_seq, not into the original packed sequence. - # The order within full_only_seq is the concatenation of each "full" split in order. 
- out = [] - full_offset = 0 - packed_idx = 0 - und_token_set = set(und_token_indexes) - for attn_mode, split_len in zip(attn_modes, split_lens): - if attn_mode == "full": - split_indices = range(packed_idx, packed_idx + split_len) - # For this "full" split, find the und tokens within this split, mapped local to full_only_seq offset - for local_idx, split_idx in enumerate(split_indices): - if split_idx in und_token_set: - out.append(full_offset + local_idx) - full_offset += split_len - packed_idx += split_len - return out - - -def _compute_mode_indices_and_offsets( - split_lens: torch.Tensor | List[int], attn_modes: List[str], mode: str, device: torch.device -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute indices from a joint tensor that are in the given mode. - """ - indices = [] - offsets = [0] - next_offset = 0 - start = 0 - - if isinstance(split_lens, torch.Tensor): - split_lens = split_lens.tolist() - - for i, (split_len, attn_mode) in enumerate(zip(split_lens, attn_modes)): - if attn_mode == mode: - indices.extend(range(start, start + split_len)) - next_offset += split_len - offsets.append(next_offset) - start += split_len - return torch.tensor(indices, dtype=torch.int32, device=device), torch.tensor( # [N_mode_tokens], [N_mode_splits+1] - offsets, dtype=torch.int32, device=device - ) - - -# Pad causal_seq and full_only_seq to have length 2048 if not already at that size -def _pad_to_N(N, x: torch.Tensor) -> torch.Tensor: - assert x.shape[0] <= N - padded = x.new_zeros((N, *x.shape[1:])) - padded[: x.shape[0]] = x - return padded - - -def _round_up_to_N(n: int, cp_world_size: int = 1, pad_for_cuda_graphs: bool = False) -> int: - if pad_for_cuda_graphs: - # Reduce recompilations / CUDA graph re-captures by bucketing lengths. 
- # <= 2K: 128, <= 4K: 256, <= 8K: 512, <= 16K: 1024, > 16K: 2048 - if n <= 2048: - alignment = 128 - elif n <= 4096: - alignment = 256 - elif n <= 8192: - alignment = 512 - elif n <= 16384: - alignment = 1024 - else: - alignment = 2048 - n = ((n + alignment - 1) // alignment) * alignment - - # ensure it's divisible by cp_world_size - if cp_world_size > 1: - remainder = n % cp_world_size - if remainder != 0: - n += cp_world_size - remainder - - return n - - -def _pad( - causal_seq: torch.Tensor, full_only_seq: torch.Tensor, max_causal_len: int, max_full_len: int -) -> tuple[torch.Tensor, torch.Tensor]: - causal_seq = _pad_to_N(max_causal_len, causal_seq) - full_only_seq = _pad_to_N(max_full_len, full_only_seq) - return causal_seq, full_only_seq - - -def _ensure_core_metadata(pack: SequencePack) -> None: - required = [ - "sample_offsets", - "max_sample_len", - "max_causal_len", - "max_full_len", - "_causal_indices", - "_full_indices", - "_causal_seq_offsets", - "_full_only_seq_offsets", - "is_sharded", - ] - for key in required: - if key not in pack: - raise KeyError(f"Missing required pack field: {key}") - - -def _init_sequence_pack( - sample_lens: List[int], - split_lens: List[int], - attn_modes: List[str], - device: torch.device, -) -> dict[str, Any]: - _max_sample_len = max(sample_lens) - _max_causal_len = max((split_lens[i] for i in range(len(split_lens)) if attn_modes[i] == "causal"), default=0) - _max_full_len = max((split_lens[i] for i in range(len(split_lens)) if attn_modes[i] == "full"), default=0) - - sample_lens_cu = torch.tensor([0] + sample_lens, device=device, dtype=torch.int32) # [N_samples+1] - _sample_offsets = torch.cumsum(sample_lens_cu, dim=0, dtype=torch.int32) # [N_samples+1] - - _causal_indices, _causal_seq_offsets = _compute_mode_indices_and_offsets(split_lens, attn_modes, "causal", device) - _full_indices, _full_only_seq_offsets = _compute_mode_indices_and_offsets(split_lens, attn_modes, "full", device) - - return dict( - 
sample_offsets=_sample_offsets, - max_sample_len=_max_sample_len, - max_causal_len=_max_causal_len, - max_full_len=_max_full_len, - _causal_indices=_causal_indices, - _full_indices=_full_indices, - _causal_seq_offsets=_causal_seq_offsets, - _full_only_seq_offsets=_full_only_seq_offsets, - _num_causal_tokens=len(_causal_indices), - _num_full_tokens=len(_full_indices), - split_lens=split_lens, - attn_modes=attn_modes, - ) - - -# ------------------------------------ -# SequencePack constructors -# ------------------------------------ - - -def _round_up_for_cuda_graphs_or_cp( - causal_seq: torch.Tensor, - full_only_seq: torch.Tensor, - need_causal: int, - need_full: int, - is_image_batch: bool, - pad_for_cuda_graphs: bool, -) -> tuple[torch.Tensor, torch.Tensor]: - """Pad causal/full sequences to the required lengths, growing global bounds for CUDA graphs.""" - if pad_for_cuda_graphs: - global \ - MAX_CAUSAL_LEN_IMAGE_BATCH, \ - MAX_FULL_LEN_IMAGE_BATCH, \ - MAX_CAUSAL_LEN_VIDEO_BATCH, \ - MAX_FULL_LEN_VIDEO_BATCH - if is_image_batch: - if need_causal > MAX_CAUSAL_LEN_IMAGE_BATCH: - MAX_CAUSAL_LEN_IMAGE_BATCH = need_causal - log.info(f"Growing MAX_CAUSAL_LEN_IMAGE_BATCH to {MAX_CAUSAL_LEN_IMAGE_BATCH}", rank0_only=False) - if need_full > MAX_FULL_LEN_IMAGE_BATCH: - MAX_FULL_LEN_IMAGE_BATCH = need_full - log.info(f"Growing MAX_FULL_LEN_IMAGE_BATCH to {MAX_FULL_LEN_IMAGE_BATCH}", rank0_only=False) - causal_seq, full_only_seq = _pad( - causal_seq, - full_only_seq, - max_causal_len=MAX_CAUSAL_LEN_IMAGE_BATCH, - max_full_len=MAX_FULL_LEN_IMAGE_BATCH, - ) - else: - if need_causal > MAX_CAUSAL_LEN_VIDEO_BATCH: - MAX_CAUSAL_LEN_VIDEO_BATCH = need_causal - log.info(f"Growing MAX_CAUSAL_LEN_VIDEO_BATCH to {MAX_CAUSAL_LEN_VIDEO_BATCH}", rank0_only=False) - if need_full > MAX_FULL_LEN_VIDEO_BATCH: - MAX_FULL_LEN_VIDEO_BATCH = need_full - log.info(f"Growing MAX_FULL_LEN_VIDEO_BATCH to {MAX_FULL_LEN_VIDEO_BATCH}", rank0_only=False) - causal_seq, full_only_seq = _pad( - causal_seq, - 
full_only_seq, - max_causal_len=MAX_CAUSAL_LEN_VIDEO_BATCH, - max_full_len=MAX_FULL_LEN_VIDEO_BATCH, - ) - elif need_causal != int(causal_seq.shape[0]) or need_full != int(full_only_seq.shape[0]): - causal_seq, full_only_seq = _pad(causal_seq, full_only_seq, need_causal, need_full) - return causal_seq, full_only_seq - - -def factored_from_joint_sequence( - packed_sequence: torch.Tensor, - attn_modes: List[str], - split_lens: List[int], - sample_lens: List[int], - packed_und_token_indexes: torch.Tensor, - packed_gen_token_indexes: torch.Tensor, - is_image_batch: bool = False, - cp_world_size: int = 1, - pad_for_cuda_graphs: bool = False, -) -> FactoredSequencePack: - """ - Create a factored sequence pack from a packed sequence and metadata. - NOTE: Some arguments seem redundant because they in principle support more flexible sequence setups. - This constructor checks that the required invariants for FactoredSequencePack are satisfied. - NOTE: This constructor checks that there are no "und" tokens under "full" mode, and no "gen" tokens under "causal" mode, - since this is a requirement for FactoredSequencePack. - Args: - packed_sequence (torch.Tensor): Tensor containing all tokens in the batch of sequences. - attn_modes (List[str]): List of attention modes. Must be alternating ["causal", "full", ... "causal", "full"] - split_lens (List[int]): Length of each subsequence. len(split_lens) == len(attn_modes) - sample_lens (List[int]): Length of each sequence. len(sample_lens) == number of samples. - packed_und_token_indexes (torch.Tensor): The indexes of the understanding tokens in the packed sequence. - packed_gen_token_indexes (torch.Tensor): The indexes of the generating tokens in the packed sequence. 
- """ - del packed_gen_token_indexes - - non_causal_text_idxs = _find_non_causal_text_token_idx(attn_modes, split_lens, packed_und_token_indexes.tolist()) - assert len(non_causal_text_idxs) == 0, "non_causal_text_idxs should be empty" - - assert sum(sample_lens) == packed_sequence.shape[0], ( - "sum(sample_lens) must be equal to the length of the packed sequence" - ) - - meta = _init_sequence_pack(sample_lens, split_lens, attn_modes, packed_sequence.device) - causal_seq = packed_sequence[meta["_causal_indices"]] # [N_causal_tokens,D] - full_only_seq = packed_sequence[meta["_full_indices"]] # [N_full_tokens,D] - - need_causal = _round_up_to_N(int(causal_seq.shape[0]), cp_world_size, pad_for_cuda_graphs) - need_full = _round_up_to_N(int(full_only_seq.shape[0]), cp_world_size, pad_for_cuda_graphs) - - causal_seq, full_only_seq = _round_up_for_cuda_graphs_or_cp( - causal_seq, - full_only_seq, - need_causal, - need_full, - is_image_batch, - pad_for_cuda_graphs, - ) - - pack: FactoredSequencePack = { - **meta, - "max_num_tokens": sum(sample_lens), - "causal_seq": causal_seq, - "full_only_seq": full_only_seq, - "is_sharded": False, - } - return pack - - -def _validate_single_dim_params(params: Mapping, layer_idx: int, num_dims: int | None) -> dict: - """ - Helper function to validate NATTEN parameters for a dimensionality profile. 
- - Args: - params (Mapping): parameter dict with window_size/window_size_float and other params - layer_idx (int): layer index for error messages - num_dims (int | None): 1, 2, 3, or None (for single-profile format) - - Returns: - dict: validated parameter dict with proper types - """ - if not isinstance(params, Mapping): - dim_str = f" ({num_dims}-D)" if num_dims else "" - raise ValueError(f"Parameters for layer {layer_idx}{dim_str} must be a dict or None, got {params=}.") - - is_causal = False if "is_causal" not in params else params["is_causal"] - - if "window_size_float" in params: - window_size_float = params["window_size_float"] - if ( - not isinstance(window_size_float, Sequence) - or len(window_size_float) not in [1, 2, 3] - or any(not isinstance(x, float) for x in window_size_float) - ): - raise ValueError(f"'window_size_float' must be a float tuple of size 1, 2, or 3, got {window_size_float=}") - window_size_float = tuple(k for k in window_size_float) - - num_dims = len(window_size_float) - - def check_stride_dilation(x): - if isinstance(x, float): - if 0.0 <= x <= 1.0: - return tuple(x for _ in range(num_dims)) - elif ( - isinstance(x, Sequence) - and len(x) == num_dims - and all(isinstance(y, float) and 0.0 <= y <= 1.0 for y in x) - ): - return tuple(y for y in x) - else: - raise ValueError(f"Invalid natten float parameter: {x=}") - - stride_float = 0.0 if "stride_float" not in params else params["stride_float"] - dilation_float = 0.0 if "dilation_float" not in params else params["dilation_float"] - - stride_float = check_stride_dilation(stride_float) - dilation_float = check_stride_dilation(dilation_float) - is_causal = check_valid_tuple_or_element( - is_causal, num_dims=num_dims, typename=bool, raise_error=True, param_name="is_causal" - ) - - if any(x in params for x in ["window_size", "stride", "dilation"]): - raise ValueError( - f"Please either use _float parameters, or integer ones, and not mix the two. Got {params=}." 
- ) - - return { - "window_size_float": window_size_float, - "stride_float": stride_float, - "dilation_float": dilation_float, - "is_causal": is_causal, - } - - elif "window_size" in params: - window_size = params["window_size"] - num_dims = len(window_size) - - stride = 1 if "stride" not in params else params["stride"] - dilation = 1 if "dilation" not in params else params["dilation"] - - if any("_float" in x for x in params.keys()): - raise ValueError( - f"Please either use _float parameters, or integer ones, and not mix the two. Got {params=}." - ) - - window_size = check_valid_tuple_or_element( - window_size, num_dims=num_dims, typename=int, raise_error=True, param_name="window_size" - ) - stride = check_valid_tuple_or_element( - stride, num_dims=num_dims, typename=int, raise_error=True, param_name="stride" - ) - dilation = check_valid_tuple_or_element( - dilation, num_dims=num_dims, typename=int, raise_error=True, param_name="dilation" - ) - is_causal = check_valid_tuple_or_element( - is_causal, num_dims=num_dims, typename=bool, raise_error=True, param_name="is_causal" - ) - - return {"window_size": window_size, "stride": stride, "dilation": dilation, "is_causal": is_causal} - else: - raise ValueError( - "Sparse parameters for a layer must have key 'window_size' or 'window_size_float', " - f"got {params=} in layer index {layer_idx}." - ) - - -def verify_natten_parameter_list( - natten_parameter_list: list | None, - num_layers: int, -) -> list | None: - """ - Converts list of NATTEN parameters into expected types, and assigns defaults to unset - parameters. - This needs to be done separately during model initialization, and not forward pass. - There are no torch operations in this function. - - Args: - natten_parameter_list (list | None): list of NATTEN parameters. Must be either None, or a - list of mappings, one for each layer. 
Each list element must be either None, - representing no sparsity / masking (full dense attention), or a mapping of NATTEN - parameters. - - Parameters can be specified directly with integer or float format: - - 'window_size_float' (required), 'stride_float', 'dilation_float' - - 'window_size' (required), 'stride', 'dilation' - - Or, parameters can be specified for multiple dimensionality profiles in case of - mixed-training (i.e. image and video training) using keys "1d", "2d", "3d": - - Each key maps to either None (dense attention) or a parameter dict - - Integer and float parameters cannot be used together in the same layer! - Additionally, you can specify 'is_causal'. - - Examples: - ``` - # 50 percent sparsity along each dimension in a 2-D token layout - {'window_size_float': (0.5, 0.5)} # valid - - # 50 percent sparsity along each dimension in a 2-D token layout - # Maximum dilation along first dimension, no dilation along second dimension - {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid - - # Fixed window size of 8x8, dilation of 2x1. - # NOTE: requires ALL inputs to be at least 16x8 - {'window_size': (8, 8), 'dilation': (2, 1)} # valid - - # Multi-profile: different parameters for 2D (images) and 3D (videos) - { - "2d": {"window_size_float": (0.5, 0.5)}, - "3d": {"window_size_float": (1.0, 0.5, 0.5)} - } # valid - - # Multi-profile: 2D uses dense attention, 3D uses sparse - { - "2d": None, - "3d": {"window_size_float": (1.0, 0.5, 0.5)} - } # valid - - # Invalid: - {'window_size_float': (0.5, 0.5), 'dilation': (2, 1)} - ``` - - num_layers (int): number of layers in the model. Just used to verify list length. - - Returns: - output_parameter_list (list | None): verified and type-checked NATTEN parameters, or None if - no parameters passed. 
- """ - - if natten_parameter_list is not None: - parameter_list_out = [] - if not isinstance(natten_parameter_list, Sequence): - raise ValueError(f"Argument 'natten_parameter_list' must be a list or None, got {natten_parameter_list=}.") - - if len(natten_parameter_list) != num_layers: - raise ValueError( - "Number of elements in 'natten_parameter_list' must match number of layers " - f"in the model, got {num_layers=}, {len(natten_parameter_list)=}." - ) - - for i, layer_parameters in enumerate(natten_parameter_list): - if layer_parameters is None: - log.debug(f"Layer {i} will use DENSE attention.") - parameter_list_out.append(None) - continue - - if not isinstance(layer_parameters, Mapping): - raise ValueError( - f"Sparse parameters for a layer must be a dict or None, got {layer_parameters=} in layer index {i}." - ) - - # Detect format: multi-profile if has keys "1d", "2d", or "3d" - dim_keys = {"1d", "2d", "3d"} - has_dim_keys = any(k in layer_parameters for k in dim_keys) - - if has_dim_keys: - # Multi-profile format: validate each explicitly defined dimensionality profile - validated_multi_profile = {} - for dim_str, dim_int in [("1d", 1), ("2d", 2), ("3d", 3)]: - if dim_str in layer_parameters: - dim_params = layer_parameters[dim_str] - if dim_params is None: - validated_multi_profile[dim_int] = None - else: - validated_multi_profile[dim_int] = _validate_single_dim_params(dim_params, i, dim_int) - else: - # Single-profile format: validate and convert to multi-profile format - # Infer dimensionality from parameter tuple length - validated_params = _validate_single_dim_params(layer_parameters, i, None) - if "window_size_float" in validated_params: - num_dims = len(validated_params["window_size_float"]) - else: # "window_size" - num_dims = len(validated_params["window_size"]) - validated_multi_profile = {num_dims: validated_params} - - # If all explicitly defined profiles are None, treat as fully dense layer - if all(v is None for v in 
validated_multi_profile.values()): - log.debug(f"Layer {i} will use DENSE attention (all profiles None).") - parameter_list_out.append(None) - else: - parameter_list_out.append(validated_multi_profile) - log.info(f"Layer {i} NATTEN parameters: {validated_multi_profile}") - - return parameter_list_out - - return None - - -def generate_natten_metadata( - token_shapes: list[tuple[int, int, int]], - head_dim: int, - num_layers: int, - device: torch.device, - dtype: torch.dtype, - requires_grad: bool, - natten_parameter_list: list | None = None, -) -> list | None: - """ - Generates list of metadata required by Variable-Sized (variable-length) operations in NATTEN. - Required when training with three_way attention and NATTEN (multi-dimensional / sparse - attention). - - Args: - token_shapes (list[tuple]): list of integer tuples corresponding to the - post-tokenization/patchify token layout shapes in the packed sequence. Must strictly be - integer tuples with the same profile (all 1D, 2D, or 3D). 1s will be automatically - stripped (i.e. [(1, 8, 8), (1, 16, 16)] is interpreted as [(8, 8), (16, 16)]). - - head_dim (int): Attention head dimension (used to select NATTEN kernel configurations). - - num_layers (int): number of layers in the model. Just used to verify list length. - - device (torch.device): PyTorch device for offset tensors (should match QKV device). - - dtype (torch.dtype): Expected QKV dtype. - - requires_grad (bool): Determines whether backprop is expected, and sets up metadata for - backward pass as well. - - natten_parameter_list (list | None): list of NATTEN parameters. Must be either None, or a - list of mappings, one for each layer. 
Each list element must be either None, - representing no sparsity / masking (full dense attention), or a mapping of NATTEN - parameters in either integer or float format: - - 'window_size_float' (required), 'stride_float', 'dilation_float' - - 'window_size' (required), 'stride', 'dilation' - - Integer and float parameters cannot be used together in the same layer! - Additionally, you can specify 'is_causal'. - - Examples: - ``` - # 50 percent sparsity along each dimension in a 2-D token layout - {'window_size_float': (0.5, 0.5)} # valid - - # 50 percent sparsity along each dimension in a 2-D token layout - # Maximum dilation along first dimension, no dilation along second dimension - {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid - - # Fixed window size of 8x8, dilation of 2x1. - # NOTE: requires ALL inputs to be at least 16x8 - {'window_size': (8, 8), 'dilation': (2, 1)} # valid - - # Invalid: - {'window_size_float': (0.5, 0.5), 'dilation': (2, 1)} - ``` - - Returns: - natten_metadata_list (list | None): list of NATTEN varlen metadata, or Nones (dense layers). - Each non-None element will be a dictionary containing final parameters, and varlen - metadata (offset and size tensors, max lengths). - NOTE: to avoid excessive recompilations in torch.compile, we must carefully index into - this list during model.forward, and ideally using the iteration counter from the loop - over layers (nn.ModuleList). 
- """ - - - if token_shapes is None or len(token_shapes) < 1: - raise ValueError("'token_shapes' is required for 'three_way' attention.") - - natten_metadata = None - - if natten_parameter_list is not None: - natten_metadata = [] - if not isinstance(natten_parameter_list, list): - raise ValueError(f"Argument 'natten_parameter_list' must be a list or None, got {natten_parameter_list=}.") - - if len(natten_parameter_list) != num_layers: - raise ValueError( - "Number of elements in 'natten_parameter_list' must match number of layers " - f"in the model, got {num_layers=}, {len(natten_parameter_list)=}." - ) - - # We need to filter out 1s from shapes - def filter_shape(shape: tuple) -> tuple: - return tuple(x for x in shape if x > 1) - - # Infer token layout rank (dimensionality) - num_dims = max([len(filter_shape(token_shape)) for token_shape in token_shapes]) - - # Single pass: check if all layers support this dimensionality and if any need processing - needs_processing = False - for i, layer_parameters in enumerate(natten_parameter_list): - if layer_parameters is None: - continue - - # Fail fast if this dimensionality is not defined - if num_dims not in layer_parameters: - raise ValueError( - f"Layer {i}: batch has {num_dims}D data but parameters are not defined for {num_dims}D. 
" - f"Defined dimensionalities: {sorted(layer_parameters.keys())}" - ) - - # Check if this layer needs processing for this dimensionality - if layer_parameters[num_dims] is not None: - needs_processing = True - - # Early exit if all layers are dense for this dimensionality profile - if not needs_processing: - log.debug(f"All layers use DENSE attention for {num_dims}D data.") - return None - - # We actually need to process, so validate and filter all shapes - token_layout_list = [] - for shape in token_shapes: - assert isinstance(shape, tuple) - shape_filtered = filter_shape(shape) - assert len(shape_filtered) == num_dims, ( - f"All data in batch must have same dimensionality, got {num_dims}D and {len(shape_filtered)}D" - ) - token_layout_list.append(shape_filtered) - - log.debug(f"Batch dimensionality: {num_dims}D, token_layout_list={token_layout_list}") - - for i, layer_parameters in enumerate(natten_parameter_list): - if layer_parameters is None: - natten_metadata.append(None) - continue - - # Get parameters for this dimensionality (already validated above) - dim_params = layer_parameters[num_dims] - - if dim_params is None: - # Dense attention for this dimensionality - natten_metadata.append(None) - continue - - # Use dim_params (parameters for this specific dimensionality) - window_size_list = [] - stride_list = [] - dilation_list = [] - - if "window_size_float" in dim_params: - window_size_float = dim_params["window_size_float"] - stride_float = dim_params["stride_float"] - dilation_float = dim_params["dilation_float"] - - for token_layout in token_layout_list: - window_size_ = tuple( - min(x, max(2, int(k * float(x)))) for k, x in zip(window_size_float, token_layout) - ) - stride_ = tuple(min(k, max(1, int(s * float(k)))) for s, k in zip(stride_float, window_size_)) - max_dilation = tuple(x // k for k, x in zip(window_size_, token_layout)) - dilation_ = tuple(min(m, max(1, int(d * float(m)))) for d, m in zip(dilation_float, max_dilation)) - - 
window_size_list.append(window_size_) - stride_list.append(stride_) - dilation_list.append(dilation_) - - assert len(window_size_list) == len(stride_list) == len(dilation_list) == len(token_layout_list) - - log.debug(f"Layer {i}: {window_size_list=}") - log.debug(f"Layer {i}: {stride_list=}") - log.debug(f"Layer {i}: {dilation_list=}") - - elif "window_size" in dim_params: - window_size = dim_params["window_size"] - stride = dim_params["stride"] - dilation = dim_params["dilation"] - - window_size_list = [window_size for _ in range(len(token_layout_list))] - stride_list = [stride for _ in range(len(token_layout_list))] - dilation_list = [dilation for _ in range(len(token_layout_list))] - else: - raise ValueError( - "Sparse parameters for a layer must have key 'window_size' or 'window_size_float', " - f"got {dim_params=} in layer index {i}." - ) - - is_causal = dim_params["is_causal"] - - # Create varlen metadata for natten varlen/varsized ops - # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to - # full size, that's why constant window sizes aren't allowed. - # NOTE: if any of the parameters are constant, natten will simplify them - natten_metadata.append( - generate_multi_dim_varlen_parameters( - token_layout_list=token_layout_list, - head_dim=head_dim, - device=device, - dtype=dtype, - requires_grad=requires_grad, - # - window_size_list=window_size_list, - stride_list=stride_list, - dilation_list=dilation_list, - # - is_causal=is_causal, - ) - ) - - return natten_metadata - - -def generate_temporal_causal_natten_metadata( - vision_token_shapes: list[tuple[int, int, int]], - num_action_tokens_per_supertoken: int, - num_layers: int, - head_dim: int, - device: torch.device, - dtype: torch.dtype, - requires_grad: bool, -) -> list: - """Generate per-layer varlen metadata for temporal causal attention on supertokens. 
- - Each sample's generation tokens are laid out as T_i supertokens of size - S_i = num_action_tokens_per_supertoken + H_i*W_i. Metadata encodes - is_causal=(True, False): causal across T, full within S. All layers share - the same metadata (full window, no spatial sparsity). - - Unlike generate_natten_metadata, this function does not apply filter_shape — (T, S) layouts - are passed directly even when T=1. NATTEN handles T=1 causal masking correctly (trivially - full attention within S). - - Args: - vision_token_shapes: List of (T, H, W) per sample. - num_action_tokens_per_supertoken: Number of action tokens prefixing each - supertoken (0 when actions are not packed inline). - num_layers: Number of transformer layers. - head_dim: Attention head dimension. - device: Target device. - dtype: Target dtype. - requires_grad: Whether metadata tensors require gradient. - - Returns: - List of length num_layers, each element the same NATTEN varlen metadata dict. - """ - # T=1: NATTEN requires kernel_size >= 2 and kernel_size <= token_layout, which are mutually - # exclusive when T=1. Fall back to full dense attention (None) — a single supertoken trivially - # attends to only itself, so temporal causality is already satisfied. - # Mixed T=1/T>1 batches are rejected: NATTEN can't mask T=1 samples, and falling back to dense - # attention for the whole batch would break temporal causality for the T>1 samples. - # Ensure min_frames >= 5 in the dataloader so that T_latent = 1 + (N-1)//tcf >= 2 always. - has_short = any(t < 2 for t, h, w in vision_token_shapes) - if has_short: - if not all(t < 2 for t, h, w in vision_token_shapes): - raise ValueError( - "Mixed T=1 and T>1 samples in causal training batch: NATTEN cannot apply " - "causal masking when any sample has T=1 (kernel_size constraint), and falling " - "back to dense attention would break temporal causality for T>1 samples. " - "Ensure all samples have T_latent >= 2 (set min_frames >= 5 in the dataloader)." 
- ) - return [None] * num_layers - token_layout_list = [(t, num_action_tokens_per_supertoken + h * w) for t, h, w in vision_token_shapes] - metadata = generate_multi_dim_varlen_parameters( - token_layout_list=token_layout_list, - head_dim=head_dim, - device=device, - dtype=dtype, - requires_grad=requires_grad, - is_causal=(True, False), - ) - return [metadata] * num_layers - - -def joint_from_joint_sequence( - packed_sequence: torch.Tensor, - attn_modes: List[str], - split_lens: List[int], - sample_lens: List[int], - packed_und_token_indexes: torch.Tensor, - packed_gen_token_indexes: torch.Tensor, - is_image_batch: bool = False, - cp_world_size: int = 1, - pad_for_cuda_graphs: bool = False, -) -> JointSequencePack: - f""" - Create a JointSequencePack from a packed sequence and metadata. - This is in order to support the legacy joint flex-attention implementation. - Differently from FactoredSequencePack, it has less strict requirements on the packed sequence. - - Args: - packed_sequence (torch.Tensor): Tensor containing all tokens in the batch of sequences. - attn_modes (List[str]): List of attention modes. Supports any sequence of {"causal", "full", "noise"} - split_lens (List[int]): Length of each subsequence. len(split_lens) == len(attn_modes) - sample_lens (List[int]): Length of each sequence. In this mode, sequences may have different number of splits, - as opposed to FactoredSequencePack where each sequence has exactly two splits.. - packed_und_token_indexes (torch.Tensor): The indexes of the understanding tokens in the packed sequence. - packed_gen_token_indexes (torch.Tensor): The indexes of the generating tokens in the packed sequence. 
- """ - assert sum(sample_lens) == packed_sequence.shape[0], ( - "sum(sample_lens) must be equal to the length of the packed sequence" - ) - meta = _init_sequence_pack(sample_lens, split_lens, attn_modes, packed_sequence.device) - pack: JointSequencePack = { - **meta, - "max_num_tokens": sum(sample_lens), - "packed_sequence": packed_sequence, - "packed_und_token_indexes": packed_und_token_indexes, - "packed_gen_token_indexes": packed_gen_token_indexes, - "is_sharded": False, - } - return pack - - -def zeros_like(orig: FactoredSequencePack | JointSequencePack, shape: Tuple[int, ...] | torch.Size | None = None): - """ - Create a new sequence pack with the same metadata as the original, but with all tokens set to zero. - Args: - orig (FactoredSequencePack | JointSequencePack): The original sequence pack to copy metadata from. - shape (Tuple[int, ...] | torch.Size | None): The shape of the new sequence pack. If None, the shape will be the same as the original. - """ - _ensure_core_metadata(orig) - if "packed_sequence" in orig: - if shape is None: - shape_ = orig["packed_sequence"].shape - else: - assert len(shape) >= 1 and shape[0] == -1 - shape_ = (orig["packed_sequence"].shape[0],) + tuple(shape)[1:] - packed_sequence = torch.zeros( - shape_, device=orig["packed_sequence"].device, dtype=orig["packed_sequence"].dtype - ) # [seq_len,D] - return from_joint(packed_sequence, orig) - else: - if shape is None: - shape_causal = orig["causal_seq"].shape - shape_full = orig["full_only_seq"].shape - else: - assert len(shape) >= 1 and shape[0] == -1 - shape_causal = (orig["causal_seq"].shape[0],) + tuple(shape)[1:] - shape_full = (orig["full_only_seq"].shape[0],) + tuple(shape)[1:] - causal_seq = torch.zeros( - shape_causal, device=orig["causal_seq"].device, dtype=orig["causal_seq"].dtype - ) # [N_causal_tokens,D] - full_only_seq = torch.zeros( - shape_full, device=orig["full_only_seq"].device, dtype=orig["full_only_seq"].dtype - ) # [N_full_tokens,D] - return 
from_mode_splits(causal_seq, full_only_seq, orig) - - -def from_joint(packed_sequence: torch.Tensor, metadata_source: FactoredSequencePack | JointSequencePack): - """ - Create a new sequence pack from a packed sequence and another sequence pack with the same metadata. - Args: - packed_sequence (torch.Tensor): Tensor containing all tokens in the batch of sequences. - metadata_source (FactoredSequencePack | JointSequencePack): The metadata source to copy from. - """ - _ensure_core_metadata(metadata_source) - if "packed_sequence" in metadata_source: - out = dict(metadata_source) - out["packed_sequence"] = packed_sequence - return out - else: - if metadata_source["is_sharded"]: - # Use sharded sequences as is when is_sharded is True (used in Context Parallel) - causal_seq = packed_sequence[: len(metadata_source["causal_seq"])] # [N_causal_tokens,D] - full_only_seq = packed_sequence[len(metadata_source["causal_seq"]) :] # [N_full_tokens,D] - else: - causal_seq = packed_sequence[metadata_source["_causal_indices"]] # [N_causal_tokens,D] - full_only_seq = packed_sequence[metadata_source["_full_indices"]] # [N_full_tokens,D] - causal_seq, full_only_seq = _pad( - causal_seq, - full_only_seq, - max_causal_len=metadata_source["causal_seq"].shape[0], - max_full_len=metadata_source["full_only_seq"].shape[0], - ) - - return from_mode_splits(causal_seq, full_only_seq, metadata_source) - - -def from_mode_splits( - causal_seq: torch.Tensor, - full_only_seq: torch.Tensor, - orig: FactoredSequencePack | JointSequencePack, - is_sharded: bool | None = None, -): - """ - Create a new sequence pack from two mode splits. - Args: - causal_seq (torch.Tensor): The causal sequence. - full_only_seq (torch.Tensor): The full-only sequence. - orig (FactoredSequencePack | JointSequencePack): The metadata source to copy from. - is_sharded (bool | None): If True, create a local pack for context parallel. - If None, inherits from orig. 
- """ - _ensure_core_metadata(orig) - if is_sharded is None: - is_sharded = orig.get("is_sharded", False) - - if "packed_sequence" in orig: - all_len = int(orig["_causal_indices"].shape[0] + orig["_full_indices"].shape[0]) - packed_sequence = causal_seq.new_zeros((all_len, *causal_seq.shape[1:])) # [seq_len,D] - packed_sequence[orig["_causal_indices"]] = causal_seq - packed_sequence[orig["_full_indices"]] = full_only_seq - return from_joint(packed_sequence, orig) - else: - out = dict(orig) - out["causal_seq"] = causal_seq - out["full_only_seq"] = full_only_seq - out["is_sharded"] = is_sharded - return out - - -def from_und_gen_splits(und_seq: torch.Tensor, gen_seq: torch.Tensor, orig: FactoredSequencePack | JointSequencePack): - """ - Create a new sequence pack from two und/gen splits. - Args: - und_seq (torch.Tensor): The understanding sequence. - gen_seq (torch.Tensor): The generating sequence. - orig (FactoredSequencePack | JointSequencePack): The metadata source to copy from. - """ - # If we have a joint pack (single packed_sequence), place by und/gen indexes. - if "packed_sequence" in orig and "packed_und_token_indexes" in orig and "packed_gen_token_indexes" in orig: - all_len = int(und_seq.shape[0] + gen_seq.shape[0]) - packed_sequence = und_seq.new_zeros((all_len, *und_seq.shape[1:])) # [seq_len,D] - packed_sequence[orig["packed_und_token_indexes"]] = und_seq - packed_sequence[orig["packed_gen_token_indexes"]] = gen_seq - return from_joint(packed_sequence, orig) - # Otherwise, treat und/gen as mode splits (und == causal; gen == full). - return from_mode_splits(und_seq, gen_seq, orig) - - -# ------------------------------------ -# Getters and setters for SequencePack -# ------------------------------------ -def get_und_seq(pack: SequencePack) -> torch.Tensor: - """ - Get all understanding tokens in a sequence pack in a single tensor. - - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the understanding sequence from. 
- Returns: - torch.Tensor: All understanding tokens concatenated over all sequences in the batch. - """ - if "causal_seq" in pack: - return pack["causal_seq"] - if "packed_sequence" in pack and "packed_und_token_indexes" in pack: - return pack["packed_sequence"][pack["packed_und_token_indexes"]] - raise KeyError("Cannot derive und_seq from provided pack") - - -def set_und_seq(pack: SequencePack, value: torch.Tensor) -> None: - """ - Override the understanding tokens in a sequence pack. - The order of tokens passed in must correspond to the order of tokens returned by get_und_seq. - - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to set the understanding sequence in. - value (torch.Tensor): The understanding sequence to set. - """ - if "packed_sequence" in pack and "packed_und_token_indexes" in pack: - pack["packed_sequence"][pack["packed_und_token_indexes"]] = value - elif "causal_seq" in pack: - pack["causal_seq"] = value - else: - raise KeyError("Cannot set und_seq from provided pack") - - -def get_gen_seq(pack: SequencePack) -> torch.Tensor: - """ - Get all generating tokens in a sequence pack in a single tensor. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the generating sequence from. - Returns: - torch.Tensor: All generating tokens concatenated over all sequences in the batch. - """ - if "full_only_seq" in pack: - return pack["full_only_seq"] - if "packed_sequence" in pack and "packed_gen_token_indexes" in pack: - return pack["packed_sequence"][pack["packed_gen_token_indexes"]] - raise KeyError("Cannot derive gen_seq from provided pack") - - -def set_gen_seq(pack: SequencePack, value: torch.Tensor) -> None: - """ - Override the generating tokens in a sequence pack. - The order of tokens passed in must correspond to the order of tokens returned by get_gen_seq. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to set the generating sequence in. 
- value (torch.Tensor): The generating sequence to set. - """ - if "packed_sequence" in pack and "packed_gen_token_indexes" in pack: - pack["packed_sequence"][pack["packed_gen_token_indexes"]] = value - elif "full_only_seq" in pack: - pack["full_only_seq"] = value - else: - raise KeyError("Cannot set gen_seq from provided pack") - - -def get_all_seq(pack: SequencePack) -> torch.Tensor: - """ - Get all tokens in a sequence pack in a single tensor. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the all sequence from. - Returns: - torch.Tensor: All tokens concatenated over all sequences in the batch. - """ - if "all_seq" in pack: - return pack["all_seq"] - if "packed_sequence" in pack: - return pack["packed_sequence"] - if "causal_seq" in pack and "full_only_seq" in pack: - _ensure_core_metadata(pack) - if pack["is_sharded"]: - assert False, "get_all_seq is not supported in context parallel sharded mode" - else: - out = pack["causal_seq"].new_zeros( - int(pack["_causal_indices"].shape[0] + pack["_full_indices"].shape[0]), *pack["causal_seq"].shape[1:] - ) # [seq_len,D] - if pack["causal_seq"].shape[0] > 0: - out[pack["_causal_indices"]] = pack["causal_seq"][: pack["_causal_indices"].shape[0]] - if pack["full_only_seq"].shape[0] > 0: - out[pack["_full_indices"]] = pack["full_only_seq"][: pack["_full_indices"].shape[0]] - return out - raise KeyError("Cannot derive all_seq from provided pack") - - -def set_all_seq(pack: SequencePack, value: torch.Tensor) -> None: - """ - Override the all tokens in a sequence pack. - The order of tokens passed in must correspond to the order of tokens returned by get_all_seq. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to set the all sequence in. - value (torch.Tensor): The all sequence to set. 
- """ - if "packed_sequence" in pack: - pack["packed_sequence"] = value - elif "causal_seq" in pack and "full_only_seq" in pack: - _ensure_core_metadata(pack) - pack["causal_seq"][: pack["_causal_indices"].shape[0]] = value[pack["_causal_indices"]] - pack["full_only_seq"][: pack["_full_indices"].shape[0]] = value[pack["_full_indices"]] - else: - pack["all_seq"] = value - - -def get_causal_seq(pack: SequencePack) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get the causal sequence and its offsets in a sequence pack. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the causal sequence from. - Returns: - Tuple[torch.Tensor, torch.Tensor]: The concatenated causal sub-sequences and the starting offset for each sub-sequence. - """ - _ensure_core_metadata(pack) - if "causal_seq" in pack: - return pack["causal_seq"], pack["_causal_seq_offsets"] - assert "packed_sequence" in pack - return pack["packed_sequence"][pack["_causal_indices"]], pack["_causal_seq_offsets"] - - -def get_full_only_seq(pack: SequencePack) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get the full-only sequence and its offsets in a sequence pack. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the full-only sequence from. - Returns: - Tuple[torch.Tensor, torch.Tensor]: The concatenated full-only sub-sequences and the starting offset for each sub-sequence. - """ - _ensure_core_metadata(pack) - if "full_only_seq" in pack: - return pack["full_only_seq"], pack["_full_only_seq_offsets"] - assert "packed_sequence" in pack - return pack["packed_sequence"][pack["_full_indices"]], pack["_full_only_seq_offsets"] - - -def get_device_and_dtype(pack: SequencePack) -> Tuple[torch.device, torch.dtype]: - """ - Get the device and dtype of a sequence pack. - Args: - pack (FactoredSequencePack | JointSequencePack): The sequence pack to get the device and dtype from. 
- Returns: - Tuple[torch.device, torch.dtype]: The device and dtype of the sequence pack. - """ - if "packed_sequence" in pack: - return pack["packed_sequence"].device, pack["packed_sequence"].dtype - if "causal_seq" in pack and "full_only_seq" in pack: - return pack["causal_seq"].device, pack["causal_seq"].dtype - raise KeyError("Cannot derive device and dtype from provided pack") - - -def build_sequence_plans_from_data_batch( - data_batch: dict, - input_video_key, - input_image_key: str, -) -> list[SequencePlan]: - """Build or retrieve sequence plans from a data batch dictionary. - - This function extracts sequence plans from the data batch if they exist, - otherwise creates default SequencePlan objects for each sample - in the batch. - - Args: - data_batch: Dictionary containing the data batch from the dataloader. - Expected keys include 'video' or other tensors to determine batch size. - If 'sequence_plan' key exists, those plans are returned directly. - - Returns: - List of SequencePlan objects, one per sample in the batch. - """ - # For new modalities, please generate the sequence_plan in the dataset class!!!! - - # If sequence_plan already exists in data_batch, return it - if "sequence_plan" in data_batch: - return data_batch["sequence_plan"] - - assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" - assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" - - # Determine batch size from available tensors - batch_size = 0 - for key in [input_video_key, input_image_key]: - if key in data_batch: - val = data_batch[key] - if isinstance(val, torch.Tensor): - batch_size = val.shape[0] - break - elif isinstance(val, list): - batch_size = len(val) - break - - if batch_size == 0: - raise ValueError( - f"Cannot determine batch size from data_batch. Expected {input_video_key}, {input_image_key}, or similar key." 
- ) - - # Build default SequencePlan objects - return [ - SequencePlan( - has_text=True, # Has text prompt! - has_vision=True, - condition_frame_indexes_vision=[], # No conditioning frames! - ) - for _ in range(batch_size) - ] - - -# ============================================================================ -# Demo/Test function -# ============================================================================ - - -def main(): - """Demonstrate sequence packing with sample text and images.""" - # Initialize tokenizer and add special tokens - tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") - tokenizer, _ = add_special_tokens(tokenizer) - - # Define special tokens (Note: Qwen models don't have bos_token_id) - special_tokens = { - "eos_token_id": tokenizer.eos_token_id, - "start_of_generation": tokenizer.convert_tokens_to_ids("<|vision_start|>"), - "end_of_generation": tokenizer.convert_tokens_to_ids("<|vision_end|>"), - } - - # Sample text inputs - input_strings = ["Hello world", "How are you?", "I am fine"] - - # Tokenize input strings - input_text_tokens = [tokenizer.encode(text, add_special_tokens=False) for text in input_strings] - - # Create sample images (in practice, these would be VAE latents) - input_images = torch.stack([torch.randn(3, 1, 64, 64) for _ in range(3)]) # [B, C, T, H, W] format - - # Diffusion timesteps for each image - input_timesteps = torch.tensor([0.0, 0.5, 0.9]) - - # Create GenerationDataClean for images - gen_data_clean_images = GenerationDataClean( - batch_size=3, - is_image_batch=True, - raw_state_vision=input_images, - x0_tokens_vision=torch.randn(3, 16, 8, 8), # dummy tokenized latents - raw_state_action=None, - ) - - # Create SequencePlan for each sample (all have text and vision) - sequence_plans = [ - SequencePlan( - has_text=True, - has_vision=True, - has_action=False, - condition_frame_indexes_vision=[], - condition_frame_indexes_action=[], - ) - for _ in range(3) - ] - - # Pack sequences - packed_data = 
pack_input_sequence( - sequence_plans=sequence_plans, - input_text_indexes=input_text_tokens, - gen_data_clean=gen_data_clean_images, - input_timesteps=input_timesteps, - special_tokens=special_tokens, - include_end_of_generation_token=True, - ) - - # Display results (after finalize, fields are tensors) - print(f"Packed sequence length: {packed_data.sequence_length}") - assert isinstance(packed_data.text_ids, torch.Tensor) - print(f"Packed text IDs shape: {packed_data.text_ids.shape}") - if packed_data.vision: - assert isinstance(packed_data.vision.sequence_indexes, torch.Tensor) - print(f"VAE token indexes shape: {packed_data.vision.sequence_indexes.shape}") - print(f"Packed position_ids: {packed_data.position_ids}") - - ################## - ## Video data - input_videos = torch.stack([torch.randn(3, 5, 64, 64) for _ in range(2)]) # [B, C, T, H, W] format - - # Diffusion timesteps for each video - input_timesteps_video = torch.tensor([0.5, 0.9]) - - # Create GenerationDataClean for videos - gen_data_clean_videos = GenerationDataClean( - batch_size=2, - is_image_batch=False, - raw_state_vision=input_videos, - x0_tokens_vision=torch.randn(2, 16, 2, 8, 8), # dummy tokenized latents - raw_state_action=None, - ) - - # Create SequencePlan for video samples - sequence_plans_video = [ - SequencePlan( - has_text=True, - has_vision=True, - has_action=False, - condition_frame_indexes_vision=[], - condition_frame_indexes_action=[], - ) - for _ in range(2) - ] - - # Pack sequences - packed_data = pack_input_sequence( - sequence_plans=sequence_plans_video, - input_text_indexes=input_text_tokens[0:2], - gen_data_clean=gen_data_clean_videos, - input_timesteps=input_timesteps_video, - special_tokens=special_tokens, - include_end_of_generation_token=True, - ) - - # Display results (after finalize, fields are tensors) - print(f"Packed sequence length: {packed_data.sequence_length}") - assert isinstance(packed_data.text_ids, torch.Tensor) - print(f"Packed text IDs shape: 
def get_und_position_ids(position_ids: torch.Tensor, meta: dict[str, Any]) -> torch.Tensor:
    """
    Select the position ids of the understanding tokens of a sequence pack.

    Args:
        position_ids (torch.Tensor): Position ids of the whole pack. A 2-D tensor
            (the ``(3, seq_len)`` 3D mRoPE layout) is indexed along its second
            axis; any other rank is indexed along its first axis (the
            ``(seq_len,)`` 1D RoPE layout).
        meta (dict[str, Any]): Pack metadata. Reads ``meta["is_sharded"]`` and
            ``meta["_causal_indices"]``.
    Returns:
        torch.Tensor: ``position_ids`` gathered at ``meta["_causal_indices"]``.
    Raises:
        AssertionError: If ``meta["is_sharded"]`` is truthy.
    """
    assert not meta["is_sharded"], "get_und_position_ids is not supported in context parallel sharded mode"
    causal_indices = meta["_causal_indices"]
    is_mrope_layout = position_ids.dim() == 2
    # mRoPE keeps the three coordinate rows and gathers columns: [3,N_causal_tokens].
    # Otherwise gather directly from the flat sequence: [N_causal_tokens].
    return position_ids[:, causal_indices] if is_mrope_layout else position_ids[causal_indices]


def get_gen_position_ids(position_ids: torch.Tensor, meta: dict[str, Any]) -> torch.Tensor:
    """
    Select the position ids of the generating tokens of a sequence pack.

    Args:
        position_ids (torch.Tensor): Position ids of the whole pack. A 2-D tensor
            (the ``(3, seq_len)`` 3D mRoPE layout) is indexed along its second
            axis; any other rank is indexed along its first axis (the
            ``(seq_len,)`` 1D RoPE layout).
        meta (dict[str, Any]): Pack metadata. Reads ``meta["is_sharded"]`` and
            ``meta["_full_indices"]``.
    Returns:
        torch.Tensor: ``position_ids`` gathered at ``meta["_full_indices"]``.
    Raises:
        AssertionError: If ``meta["is_sharded"]`` is truthy.
    """
    assert not meta["is_sharded"], "get_gen_position_ids is not supported in context parallel sharded mode"
    full_indices = meta["_full_indices"]
    is_mrope_layout = position_ids.dim() == 2
    # mRoPE keeps the three coordinate rows and gathers columns: [3,N_full_tokens].
    # Otherwise gather directly from the flat sequence: [N_full_tokens].
    return position_ids[:, full_indices] if is_mrope_layout else position_ids[full_indices]
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Prepare dense attention mask for a single sample with multiple splits.

    Rows are query positions and columns are key positions. Every split attends
    to all earlier splits. Inside its own span a ``"causal"`` split is lower
    triangular, while ``"full"`` and ``"noise"`` splits are fully connected.
    Columns of a ``"noise"`` split are then hidden from every other split.

    Args:
        split_lens: List of integers indicating length of each split within the sample
        attn_modes: List of attention modes for each split ('causal', 'full', or 'noise');
            must contain exactly one entry per split
        device: Device to place the attention mask tensor on

    Returns:
        Float attention mask tensor of shape (sample_len, sample_len) holding 0.0 for
        visible positions and -inf for masked positions

    Raises:
        ValueError: If ``split_lens`` and ``attn_modes`` differ in length.
        AssertionError: If an attention mode is not 'causal', 'full', or 'noise'.
    """
    # zip() would silently drop the unmatched tail and leave the rows of those
    # splits fully masked, so reject mismatched inputs up front.
    if len(split_lens) != len(attn_modes):
        raise ValueError(
            "split_lens and attn_modes must have the same length, "
            f"got {len(split_lens)} and {len(attn_modes)}."
        )

    sample_len = sum(split_lens)
    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)  # [sample_len,sample_len]

    # First pass: Set up basic attention patterns for each split
    current_pos = 0
    for split_len, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ["causal", "full", "noise"], f"Invalid attention mode: {attn_mode}"

        split_start = current_pos
        split_end = current_pos + split_len

        within_split = torch.ones((split_len, split_len), device=device)  # [split_len,split_len]
        if attn_mode == "causal":
            # Causal: lower triangular within the split
            within_split = within_split.tril()
        attention_mask[split_start:split_end, split_start:split_end] = within_split
        # All modes attend to every previous split
        attention_mask[split_start:split_end, :split_start] = True

        current_pos = split_end

    # Second pass: Handle noise mode - mask out noise columns except within same split.
    # This runs after the first pass so that it also overrides the
    # "attend to previous splits" entries written by later splits.
    current_pos = 0
    for split_len, attn_mode in zip(split_lens, attn_modes):
        split_start = current_pos
        split_end = current_pos + split_len
        if attn_mode == "noise":
            # Zero out the entire column for noise tokens
            attention_mask[:, split_start:split_end] = False
            # But allow self-attention within the noise split
            attention_mask[split_start:split_end, split_start:split_end] = True
        current_pos = split_end

    # Convert boolean mask to float with -inf for masked positions
    attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
        ~attention_mask, float("-inf")
    )  # [sample_len,sample_len]

    return attention_mask
def compute_text_split_length(
    num_caption_tokens: int,
    special_tokens: Dict[str, int],
    has_generation: bool = True,
) -> int:
    """Compute the total text split length without mutating any state.

    This is the number of token positions occupied by the text split in a
    packed sequence: caption tokens + optional BOS + EOS + optional BOV.

    Args:
        num_caption_tokens: Number of raw caption token IDs (before special tokens).
        special_tokens: Dictionary of special token IDs (checked for ``"bos_token_id"``).
        has_generation: Whether a start-of-generation (BOV) token follows text.

    Returns:
        Total text split length (positions consumed in the packed sequence).
    """
    n = num_caption_tokens
    if "bos_token_id" in special_tokens:
        n += 1
    n += 1  # EOS
    if has_generation:
        n += 1  # start-of-generation / BOV
    return n


def pack_text_tokens(
    packed_seq: PackedSequence,
    text_ids: List[int],
    special_tokens: Dict[str, int],
    curr_rope_id: int,
    has_generation: bool,
    use_float_positions: bool = False,
) -> Tuple[int, int, int]:
    """Pack text tokens into the sequence.

    The text split is laid out as ``[BOS?] + text_ids + [EOS] + [BOV?]``. Only the
    ``[BOS?] + text_ids`` positions are registered for cross-entropy loss, and each
    of them is labelled with the token that follows it in the split, so the last
    one is labelled with EOS.

    Args:
        packed_seq: PackedSequence instance to accumulate data into.
        text_ids: List of text token IDs (integers).
        special_tokens: Dictionary of special token IDs. ``"eos_token_id"`` is
            required, ``"bos_token_id"`` is optional, and ``"start_of_generation"``
            is required when ``has_generation`` is True.
        curr_rope_id: Current RoPE position ID.
        has_generation: Whether there's media/action after text.
        use_float_positions: If True, generate float position IDs for 3D mRoPE
            (for consistency with FPS-modulated vision tokens).

    Returns:
        Tuple of (updated curr_rope_id, split_length, sample_length).
    """
    # Ensure we're in build mode (fields are lists, not tensors)
    assert isinstance(packed_seq.text_ids, list), "PackedSequence must be in build mode"
    assert isinstance(packed_seq.text_indexes, list)
    assert isinstance(packed_seq.position_ids, list)
    assert isinstance(packed_seq.label_ids, list)
    assert isinstance(packed_seq.ce_loss_indexes, list)
    assert isinstance(packed_seq.ce_loss_weights, list)

    curr = packed_seq.curr
    eos_token_id = special_tokens["eos_token_id"]

    # Prepend BOS token if available
    if "bos_token_id" in special_tokens:
        shifted_text_ids = [special_tokens["bos_token_id"]] + text_ids
    else:
        shifted_text_ids = text_ids
    num_shifted = len(shifted_text_ids)

    split_len = 0

    # Add text tokens to sequence
    packed_seq.text_ids.extend(shifted_text_ids)
    packed_seq.text_indexes.extend(range(curr, curr + num_shifted))

    # Configure loss computation for text tokens
    packed_seq.ce_loss_indexes.extend(range(curr, curr + num_shifted))
    packed_seq.ce_loss_weights.extend([1.0] * num_shifted)
    # Labels are the inputs shifted left by one, ending in EOS. They must come
    # from shifted_text_ids (not text_ids) so there is exactly one label per
    # loss position when a BOS token was prepended, and none when nothing was
    # packed at all.
    if shifted_text_ids:
        packed_seq.label_ids.extend(shifted_text_ids[1:] + [eos_token_id])

    curr += num_shifted
    split_len += num_shifted

    # Add EOS token
    packed_seq.text_ids.append(eos_token_id)
    packed_seq.text_indexes.append(curr)
    curr += 1
    split_len += 1

    # Add start-of-generation token, but only if there's media/action present.
    if has_generation:
        packed_seq.text_ids.append(special_tokens["start_of_generation"])
        packed_seq.text_indexes.append(curr)
        curr += 1
        split_len += 1

    # Sanity check -- compute_text_split_length() is called elsewhere.
    assert split_len == compute_text_split_length(len(text_ids), special_tokens, has_generation)

    # Update position IDs and attention mode for text split
    if packed_seq._use_mrope:
        text_mrope_ids, packed_seq._mrope_temporal_offset = get_3d_mrope_ids_text_tokens(
            num_tokens=split_len,
            temporal_offset=packed_seq._mrope_temporal_offset,
            use_float_positions=use_float_positions,
        )  # text_mrope_ids: [3,split_len]
        packed_seq.position_ids.append(text_mrope_ids)
    else:
        packed_seq.position_ids.extend(range(curr_rope_id, curr_rope_id + split_len))
    packed_seq.attn_modes.append("causal")
    packed_seq.split_lens.append(split_len)

    packed_seq.curr = curr
    return curr_rope_id + split_len, split_len, split_len
Either a float (teacher_forcing/none — all frames + share the same sigma) or a Tensor(T_max,) (diffusion_forcing — per-frame sigma; + indexed as input_timestep[frame_idx] for each noisy frame). + curr_rope_id: Current RoPE position ID. + latent_patch_size: Patch size for latent patchification. + vision_fps: Frames per second of the video. Used when enable_fps_modulation=True. + enable_fps_modulation: If True, scale temporal position IDs based on video FPS. + base_fps: Base FPS for normalization (default 24.0). + temporal_compression_factor: VAE temporal compression factor (default 4). + vision_temporal_positions: Optional explicit temporal coordinate per latent + frame, shape ``(T,)``. Used by UniAE to account for kept boundary latents. + Returns: + Vision split length. + """ + # Ensure we're in build mode + assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" + + curr = packed_seq.curr + vision_split_len = 0 + + # Initialize vision modality if not present. 
+ if packed_seq.vision is None: + packed_seq.vision = ModalityData() + + # Ensure vision modality is in build mode + assert isinstance(packed_seq.vision.sequence_indexes, list) + assert isinstance(packed_seq.vision.mse_loss_indexes, list) + assert isinstance(packed_seq.vision.timesteps, list) + assert isinstance(packed_seq.vision.tokens, list) + + # Compute position IDs for image patches + _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape + if latent_patch_size < 1: + raise ValueError(f"latent_patch_size must be >= 1, got {latent_patch_size}") + # Use ceil to support latent dims not divisible by patch size (padding handled in network) + patch_h = math.ceil(latent_h / latent_patch_size) + patch_w = math.ceil(latent_w / latent_patch_size) + packed_seq.vision.token_shapes.append((latent_t, patch_h, patch_w)) + packed_seq.vision.tokens.append(input_vision_tokens) + + # Add image token indexes and loss information + num_vision_tokens = latent_t * patch_h * patch_w + packed_seq.vision.sequence_indexes.extend(range(curr, curr + num_vision_tokens)) + + # Supervise vision tokens based on conditioning frames + condition_set = {idx for idx in condition_frame_indexes_vision if 0 <= idx < latent_t} + assert isinstance(packed_seq.vision.condition_mask, list) + + vision_condition_mask = torch.zeros( + (latent_t, 1, 1), device=input_vision_tokens.device, dtype=input_vision_tokens.dtype + ) # [T,1,1] + for frame_idx in condition_set: + vision_condition_mask[frame_idx, 0, 0] = 1.0 + packed_seq.vision.condition_mask.append(vision_condition_mask) + + vision_noisy_frame_indexes = torch.tensor( + [idx for idx in range(latent_t) if idx not in condition_set], + device=input_vision_tokens.device, + dtype=torch.long, + ) # [N_noisy_frames] + assert isinstance(packed_seq.vision.noisy_frame_indexes, list) + packed_seq.vision.noisy_frame_indexes.append(vision_noisy_frame_indexes) + + frame_token_stride = patch_h * patch_w + for frame_idx in range(latent_t): + if frame_idx in 
def pack_action_tokens(
    packed_seq: PackedSequence,
    input_action_tokens: torch.Tensor,
    condition_frame_indexes_action: list[int],
    input_timestep: float,
    curr_rope_id: int,
    action_temporal_offset: int | float = 0,
    enable_fps_modulation: bool = False,
    base_fps: float = 24.0,
    action_fps: float | None = None,
    base_temporal_compression_factor: int | None = None,
    action_start_frame_offset: int = 1,
) -> int:
    """Append one sample's action tokens to a PackedSequence that is being built.

    Each action step occupies one sequence position starting at ``packed_seq.curr``.
    Steps listed in ``condition_frame_indexes_action`` are marked in the condition
    mask; every other step is recorded as noisy and registered for MSE loss with
    ``input_timestep``.

    Args:
        packed_seq: PackedSequence instance to accumulate data into; its
            ``action`` modality is created on first use.
        input_action_tokens: Action latent tokens. The size of the first
            dimension is the number of action steps (documented upstream as
            ``(T, D)``).
        condition_frame_indexes_action: Indexes of conditioning action steps.
            Indexes outside ``[0, T)`` are ignored.
        input_timestep: Diffusion timestep recorded for every noisy step.
        curr_rope_id: RoPE position ID assigned to all action tokens when
            mRoPE is disabled.
        action_temporal_offset: Temporal offset forwarded to the mRoPE ID helper.
        enable_fps_modulation: If True, ``action_fps`` is forwarded to the mRoPE
            ID helper; otherwise ``None`` is forwarded.
        base_fps: Base FPS forwarded to the mRoPE ID helper.
        action_fps: Frames per second of the action data. Only used when
            ``enable_fps_modulation`` is True.
        base_temporal_compression_factor: Forwarded to the mRoPE ID helper.
        action_start_frame_offset: Forwarded to the mRoPE ID helper as
            ``start_frame_offset``.

    Returns:
        Number of action tokens added.
    """
    # Ensure we're in build mode
    assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode"

    seq_start = packed_seq.curr
    num_action_tokens = input_action_tokens.shape[0]

    # Lazily create the action modality container
    if packed_seq.action is None:
        packed_seq.action = ModalityData()
    action = packed_seq.action

    # The action modality must still hold plain lists (build mode)
    assert isinstance(action.sequence_indexes, list)
    assert isinstance(action.mse_loss_indexes, list)
    assert isinstance(action.timesteps, list)
    assert isinstance(action.tokens, list)

    # Register sequence positions, token layout and the raw tokens
    action.sequence_indexes.extend(range(seq_start, seq_start + num_action_tokens))
    action.token_shapes.append((num_action_tokens,))
    action.tokens.append(input_action_tokens)

    # Keep only the conditioning indexes that fall inside this split
    conditioned = {idx for idx in condition_frame_indexes_action if 0 <= idx < num_action_tokens}
    assert isinstance(action.condition_mask, list)

    condition_mask = torch.zeros(
        (num_action_tokens, 1), device=input_action_tokens.device, dtype=input_action_tokens.dtype
    )  # [T_action,1]
    for step_idx in conditioned:
        condition_mask[step_idx, 0] = 1.0
    action.condition_mask.append(condition_mask)

    noisy_steps = [idx for idx in range(num_action_tokens) if idx not in conditioned]
    noisy_steps_tensor = torch.tensor(
        noisy_steps,
        device=input_action_tokens.device,
        dtype=torch.long,
    )  # [N_noisy_action_frames]
    assert isinstance(action.noisy_frame_indexes, list)
    action.noisy_frame_indexes.append(noisy_steps_tensor)

    # Action has 1 token per frame (no spatial dimension), so each noisy step
    # contributes exactly one MSE position and one timestep entry.
    action.mse_loss_indexes.extend(seq_start + idx for idx in noisy_steps)
    action.timesteps.extend([input_timestep] * len(noisy_steps))

    # Update RoPE position IDs for action tokens.
    if not packed_seq._use_mrope:
        # All action tokens share the SAME RoPE position as vision tokens (see docs/sequence_packing.md).
        packed_seq.position_ids.extend([curr_rope_id] * num_action_tokens)
    else:
        # 3D mRoPE: action tokens use a 1x1 spatial grid; FPS is only passed on
        # when modulation is enabled.
        fps_for_mrope = action_fps if enable_fps_modulation else None

        action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens(
            grid_t=num_action_tokens,
            grid_h=1,
            grid_w=1,
            temporal_offset=action_temporal_offset,
            reset_spatial_indices=packed_seq._mrope_reset_spatial,
            fps=fps_for_mrope,
            base_fps=base_fps,
            temporal_compression_factor=1,  # Action is at frame rate (no temporal compression)
            base_temporal_compression_factor=base_temporal_compression_factor,
            start_frame_offset=action_start_frame_offset,  # Align action[0] with vision frame action_start_frame_offset
        )  # action_mrope_ids: [3,N_action_tokens]
        packed_seq.position_ids.append(action_mrope_ids)
        # Note: _mrope_temporal_offset is deliberately left untouched because action
        # tokens share the temporal space with vision tokens (they run in parallel).

    packed_seq.curr = seq_start + num_action_tokens
    return num_action_tokens
+ Sound tokens are added to the unified generation split to maintain SequencePack's + 2-split invariant (causal + full). + + Args: + packed_seq: PackedSequence instance to accumulate data into. + input_sound_tokens: Sound latent tokens (C, T). + condition_frame_indexes_sound: Indexes of conditioning frames. + [] means all frames are noised/supervised. + All frames specified means all frames are clean (no MSE supervision). + input_timestep: Diffusion timestep. + curr_rope_id: Current RoPE position ID. + sound_temporal_offset: Temporal offset for m-RoPE position IDs (aligned with vision start). + enable_fps_modulation: If True, scale temporal positions by FPS ratio. + base_fps: Base FPS for normalization (default 24.0). + sound_fps: Sound latent FPS (e.g., 25.0). Used for FPS-aware m-RoPE positions. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. + + Returns: + Number of sound tokens added. 
+ """ + # Ensure we're in build mode + assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode" + + curr = packed_seq.curr + + # Sound latent shape: [C, T] → T tokens + _, sound_split_len = input_sound_tokens.shape + + # Initialize sound modality if not present + if packed_seq.sound is None: + packed_seq.sound = ModalityData() + + # Ensure sound modality is in build mode + assert isinstance(packed_seq.sound.sequence_indexes, list) + assert isinstance(packed_seq.sound.mse_loss_indexes, list) + assert isinstance(packed_seq.sound.timesteps, list) + assert isinstance(packed_seq.sound.tokens, list) + + # Add token indexes - sound uses (T, 1, 1) shape for compatibility with 3D RoPE + packed_seq.sound.token_shapes.append((sound_split_len, 1, 1)) + packed_seq.sound.sequence_indexes.extend(range(curr, curr + sound_split_len)) + packed_seq.sound.tokens.append(input_sound_tokens) + + # Supervise sound tokens based on conditioning frames + condition_set = {idx for idx in condition_frame_indexes_sound if 0 <= idx < sound_split_len} + assert isinstance(packed_seq.sound.condition_mask, list) + + # Condition mask: shape (T, 1) — 1 = clean/conditioning, 0 = noised/supervised + sound_condition_mask = torch.zeros( + (sound_split_len, 1), device=input_sound_tokens.device, dtype=input_sound_tokens.dtype + ) # [T_sound,1] + for frame_idx in condition_set: + sound_condition_mask[frame_idx, 0] = 1.0 + packed_seq.sound.condition_mask.append(sound_condition_mask) + + sound_noisy_frame_indexes = torch.tensor( + [idx for idx in range(sound_split_len) if idx not in condition_set], + device=input_sound_tokens.device, + dtype=torch.long, + ) # [N_noisy_sound_frames] + assert isinstance(packed_seq.sound.noisy_frame_indexes, list) + packed_seq.sound.noisy_frame_indexes.append(sound_noisy_frame_indexes) + + # Add to MSE loss indexes and timesteps for non-conditioning frames + for frame_idx in range(sound_split_len): + if frame_idx in condition_set: + continue + # 
Sound has 1 token per frame (no spatial dimension) + frame_start = curr + frame_idx + frame_end = frame_start + 1 + packed_seq.sound.mse_loss_indexes.extend(range(frame_start, frame_end)) + packed_seq.sound.timesteps.extend([input_timestep]) + + # Update RoPE position IDs for sound tokens. + if packed_seq._use_mrope: + # 3D mRoPE: sound tokens use a 1x1 spatial grid, aligned with vision temporal positions. + # sound[0] aligns with vision frame 0 (start_frame_offset=0, unlike action which offsets by 1). + effective_fps = sound_fps if enable_fps_modulation else None + + sound_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=sound_split_len, + grid_h=1, + grid_w=1, + temporal_offset=sound_temporal_offset, + reset_spatial_indices=packed_seq._mrope_reset_spatial, + fps=effective_fps, + base_fps=base_fps, + temporal_compression_factor=1, # Sound latent is already at sound_latent_fps (no further compression) + base_temporal_compression_factor=sound_base_temporal_compression_factor, + start_frame_offset=0, # Sound[0] aligns with vision frame 0 + ) # sound_mrope_ids: [3,N_sound_tokens] + packed_seq.position_ids.append(sound_mrope_ids) + # Note: we don't update _mrope_temporal_offset here because sound tokens + # share the temporal space with vision tokens (they run in parallel). + else: + # All sound tokens share the SAME RoPE position as vision/action tokens (unified generation split). 
+ packed_seq.position_ids.extend([curr_rope_id] * sound_split_len) + + packed_seq.curr = curr + sound_split_len + return sound_split_len diff --git a/cosmos_framework/model/vfm/mot/unified_3dmrope_utils.py b/cosmos_framework/data/vfm/sequence_packing/mrope.py similarity index 100% rename from cosmos_framework/model/vfm/mot/unified_3dmrope_utils.py rename to cosmos_framework/data/vfm/sequence_packing/mrope.py diff --git a/cosmos_framework/data/vfm/sequence_packing/natten.py b/cosmos_framework/data/vfm/sequence_packing/natten.py new file mode 100644 index 0000000..c917689 --- /dev/null +++ b/cosmos_framework/data/vfm/sequence_packing/natten.py @@ -0,0 +1,497 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""NATTEN parameter validation and metadata generation.""" + +from collections.abc import Mapping, Sequence + +import torch + +from cosmos_framework.model.attention.checks import check_valid_tuple_or_element +from cosmos_framework.model.attention.varlen import generate_multi_dim_varlen_parameters +from cosmos_framework.utils import log + + +def _validate_single_dim_params(params: Mapping, layer_idx: int, num_dims: int | None) -> dict: + """ + Helper function to validate NATTEN parameters for a dimensionality profile. 
+ + Args: + params (Mapping): parameter dict with window_size/window_size_float and other params + layer_idx (int): layer index for error messages + num_dims (int | None): 1, 2, 3, or None (for single-profile format) + + Returns: + dict: validated parameter dict with proper types + """ + if not isinstance(params, Mapping): + dim_str = f" ({num_dims}-D)" if num_dims else "" + raise ValueError(f"Parameters for layer {layer_idx}{dim_str} must be a dict or None, got {params=}.") + + is_causal = False if "is_causal" not in params else params["is_causal"] + + if "window_size_float" in params: + window_size_float = params["window_size_float"] + if ( + not isinstance(window_size_float, Sequence) + or len(window_size_float) not in [1, 2, 3] + or any(not isinstance(x, float) for x in window_size_float) + ): + raise ValueError(f"'window_size_float' must be a float tuple of size 1, 2, or 3, got {window_size_float=}") + window_size_float = tuple(k for k in window_size_float) + + num_dims = len(window_size_float) + + def check_stride_dilation(x): + if isinstance(x, float): + if 0.0 <= x <= 1.0: + return tuple(x for _ in range(num_dims)) + elif ( + isinstance(x, Sequence) + and len(x) == num_dims + and all(isinstance(y, float) and 0.0 <= y <= 1.0 for y in x) + ): + return tuple(y for y in x) + else: + raise ValueError(f"Invalid natten float parameter: {x=}") + + stride_float = 0.0 if "stride_float" not in params else params["stride_float"] + dilation_float = 0.0 if "dilation_float" not in params else params["dilation_float"] + + stride_float = check_stride_dilation(stride_float) + dilation_float = check_stride_dilation(dilation_float) + is_causal = check_valid_tuple_or_element( + is_causal, num_dims=num_dims, typename=bool, raise_error=True, param_name="is_causal" + ) + + if any(x in params for x in ["window_size", "stride", "dilation"]): + raise ValueError( + f"Please either use _float parameters, or integer ones, and not mix the two. Got {params=}." 
+ ) + + return { + "window_size_float": window_size_float, + "stride_float": stride_float, + "dilation_float": dilation_float, + "is_causal": is_causal, + } + + elif "window_size" in params: + window_size = params["window_size"] + num_dims = len(window_size) + + stride = 1 if "stride" not in params else params["stride"] + dilation = 1 if "dilation" not in params else params["dilation"] + + if any("_float" in x for x in params.keys()): + raise ValueError( + f"Please either use _float parameters, or integer ones, and not mix the two. Got {params=}." + ) + + window_size = check_valid_tuple_or_element( + window_size, num_dims=num_dims, typename=int, raise_error=True, param_name="window_size" + ) + stride = check_valid_tuple_or_element( + stride, num_dims=num_dims, typename=int, raise_error=True, param_name="stride" + ) + dilation = check_valid_tuple_or_element( + dilation, num_dims=num_dims, typename=int, raise_error=True, param_name="dilation" + ) + is_causal = check_valid_tuple_or_element( + is_causal, num_dims=num_dims, typename=bool, raise_error=True, param_name="is_causal" + ) + + return {"window_size": window_size, "stride": stride, "dilation": dilation, "is_causal": is_causal} + else: + raise ValueError( + "Sparse parameters for a layer must have key 'window_size' or 'window_size_float', " + f"got {params=} in layer index {layer_idx}." + ) + + +def verify_natten_parameter_list( + natten_parameter_list: list | None, + num_layers: int, +) -> list | None: + """ + Converts list of NATTEN parameters into expected types, and assigns defaults to unset + parameters. + This needs to be done separately during model initialization, and not forward pass. + There are no torch operations in this function. + + Args: + natten_parameter_list (list | None): list of NATTEN parameters. Must be either None, or a + list of mappings, one for each layer. 
Each list element must be either None, + representing no sparsity / masking (full dense attention), or a mapping of NATTEN + parameters. + + Parameters can be specified directly with integer or float format: + - 'window_size_float' (required), 'stride_float', 'dilation_float' + - 'window_size' (required), 'stride', 'dilation' + + Or, parameters can be specified for multiple dimensionality profiles in case of + mixed-training (i.e. image and video training) using keys "1d", "2d", "3d": + - Each key maps to either None (dense attention) or a parameter dict + + Integer and float parameters cannot be used together in the same layer! + Additionally, you can specify 'is_causal'. + + Examples: + ``` + # 50 percent sparsity along each dimension in a 2-D token layout + {'window_size_float': (0.5, 0.5)} # valid + + # 50 percent sparsity along each dimension in a 2-D token layout + # Maximum dilation along first dimension, no dilation along second dimension + {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid + + # Fixed window size of 8x8, dilation of 2x1. + # NOTE: requires ALL inputs to be at least 16x8 + {'window_size': (8, 8), 'dilation': (2, 1)} # valid + + # Multi-profile: different parameters for 2D (images) and 3D (videos) + { + "2d": {"window_size_float": (0.5, 0.5)}, + "3d": {"window_size_float": (1.0, 0.5, 0.5)} + } # valid + + # Multi-profile: 2D uses dense attention, 3D uses sparse + { + "2d": None, + "3d": {"window_size_float": (1.0, 0.5, 0.5)} + } # valid + + # Invalid: + {'window_size_float': (0.5, 0.5), 'dilation': (2, 1)} + ``` + + num_layers (int): number of layers in the model. Just used to verify list length. + + Returns: + output_parameter_list (list | None): verified and type-checked NATTEN parameters, or None if + no parameters passed. 
+ """ + + if natten_parameter_list is not None: + parameter_list_out = [] + if not isinstance(natten_parameter_list, Sequence): + raise ValueError(f"Argument 'natten_parameter_list' must be a list or None, got {natten_parameter_list=}.") + + if len(natten_parameter_list) != num_layers: + raise ValueError( + "Number of elements in 'natten_parameter_list' must match number of layers " + f"in the model, got {num_layers=}, {len(natten_parameter_list)=}." + ) + + for i, layer_parameters in enumerate(natten_parameter_list): + if layer_parameters is None: + log.debug(f"Layer {i} will use DENSE attention.") + parameter_list_out.append(None) + continue + + if not isinstance(layer_parameters, Mapping): + raise ValueError( + f"Sparse parameters for a layer must be a dict or None, got {layer_parameters=} in layer index {i}." + ) + + # Detect format: multi-profile if has keys "1d", "2d", or "3d" + dim_keys = {"1d", "2d", "3d"} + has_dim_keys = any(k in layer_parameters for k in dim_keys) + + if has_dim_keys: + # Multi-profile format: validate each explicitly defined dimensionality profile + validated_multi_profile = {} + for dim_str, dim_int in [("1d", 1), ("2d", 2), ("3d", 3)]: + if dim_str in layer_parameters: + dim_params = layer_parameters[dim_str] + if dim_params is None: + validated_multi_profile[dim_int] = None + else: + validated_multi_profile[dim_int] = _validate_single_dim_params(dim_params, i, dim_int) + else: + # Single-profile format: validate and convert to multi-profile format + # Infer dimensionality from parameter tuple length + validated_params = _validate_single_dim_params(layer_parameters, i, None) + if "window_size_float" in validated_params: + num_dims = len(validated_params["window_size_float"]) + else: # "window_size" + num_dims = len(validated_params["window_size"]) + validated_multi_profile = {num_dims: validated_params} + + # If all explicitly defined profiles are None, treat as fully dense layer + if all(v is None for v in 
validated_multi_profile.values()): + log.debug(f"Layer {i} will use DENSE attention (all profiles None).") + parameter_list_out.append(None) + else: + parameter_list_out.append(validated_multi_profile) + log.info(f"Layer {i} NATTEN parameters: {validated_multi_profile}") + + return parameter_list_out + + return None + + +def generate_natten_metadata( + token_shapes: list[tuple[int, int, int]], + head_dim: int, + num_layers: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool, + natten_parameter_list: list | None = None, +) -> list | None: + """ + Generates list of metadata required by Variable-Sized (variable-length) operations in NATTEN. + Required when training with three_way attention and NATTEN (multi-dimensional / sparse + attention). + + Args: + token_shapes (list[tuple]): list of integer tuples corresponding to the + post-tokenization/patchify token layout shapes in the packed sequence. Must strictly be + integer tuples with the same profile (all 1D, 2D, or 3D). 1s will be automatically + stripped (i.e. [(1, 8, 8), (1, 16, 16)] is interpreted as [(8, 8), (16, 16)]). + + head_dim (int): Attention head dimension (used to select NATTEN kernel configurations). + + num_layers (int): number of layers in the model. Just used to verify list length. + + device (torch.device): PyTorch device for offset tensors (should match QKV device). + + dtype (torch.dtype): Expected QKV dtype. + + requires_grad (bool): Determines whether backprop is expected, and sets up metadata for + backward pass as well. + + natten_parameter_list (list | None): list of NATTEN parameters. Must be either None, or a + list of mappings, one for each layer. 
Each list element must be either None, + representing no sparsity / masking (full dense attention), or a mapping of NATTEN + parameters in either integer or float format: + - 'window_size_float' (required), 'stride_float', 'dilation_float' + - 'window_size' (required), 'stride', 'dilation' + + Integer and float parameters cannot be used together in the same layer! + Additionally, you can specify 'is_causal'. + + Examples: + ``` + # 50 percent sparsity along each dimension in a 2-D token layout + {'window_size_float': (0.5, 0.5)} # valid + + # 50 percent sparsity along each dimension in a 2-D token layout + # Maximum dilation along first dimension, no dilation along second dimension + {'window_size_float': (0.5, 0.5), 'dilation_float': (1.0, 0.0)} # valid + + # Fixed window size of 8x8, dilation of 2x1. + # NOTE: requires ALL inputs to be at least 16x8 + {'window_size': (8, 8), 'dilation': (2, 1)} # valid + + # Invalid: + {'window_size_float': (0.5, 0.5), 'dilation': (2, 1)} + ``` + + Returns: + natten_metadata_list (list | None): list of NATTEN varlen metadata, or Nones (dense layers). + Each non-None element will be a dictionary containing final parameters, and varlen + metadata (offset and size tensors, max lengths). + NOTE: to avoid excessive recompilations in torch.compile, we must carefully index into + this list during model.forward, and ideally using the iteration counter from the loop + over layers (nn.ModuleList). 
+ """ + + + if token_shapes is None or len(token_shapes) < 1: + raise ValueError("'token_shapes' is required for 'three_way' attention.") + + natten_metadata = None + + if natten_parameter_list is not None: + natten_metadata = [] + if not isinstance(natten_parameter_list, list): + raise ValueError(f"Argument 'natten_parameter_list' must be a list or None, got {natten_parameter_list=}.") + + if len(natten_parameter_list) != num_layers: + raise ValueError( + "Number of elements in 'natten_parameter_list' must match number of layers " + f"in the model, got {num_layers=}, {len(natten_parameter_list)=}." + ) + + # We need to filter out 1s from shapes + def filter_shape(shape: tuple) -> tuple: + return tuple(x for x in shape if x > 1) + + # Infer token layout rank (dimensionality) + num_dims = max([len(filter_shape(token_shape)) for token_shape in token_shapes]) + + # Single pass: check if all layers support this dimensionality and if any need processing + needs_processing = False + for i, layer_parameters in enumerate(natten_parameter_list): + if layer_parameters is None: + continue + + # Fail fast if this dimensionality is not defined + if num_dims not in layer_parameters: + raise ValueError( + f"Layer {i}: batch has {num_dims}D data but parameters are not defined for {num_dims}D. 
" + f"Defined dimensionalities: {sorted(layer_parameters.keys())}" + ) + + # Check if this layer needs processing for this dimensionality + if layer_parameters[num_dims] is not None: + needs_processing = True + + # Early exit if all layers are dense for this dimensionality profile + if not needs_processing: + log.debug(f"All layers use DENSE attention for {num_dims}D data.") + return None + + # We actually need to process, so validate and filter all shapes + token_layout_list = [] + for shape in token_shapes: + assert isinstance(shape, tuple) + shape_filtered = filter_shape(shape) + assert len(shape_filtered) == num_dims, ( + f"All data in batch must have same dimensionality, got {num_dims}D and {len(shape_filtered)}D" + ) + token_layout_list.append(shape_filtered) + + log.debug(f"Batch dimensionality: {num_dims}D, token_layout_list={token_layout_list}") + + for i, layer_parameters in enumerate(natten_parameter_list): + if layer_parameters is None: + natten_metadata.append(None) + continue + + # Get parameters for this dimensionality (already validated above) + dim_params = layer_parameters[num_dims] + + if dim_params is None: + # Dense attention for this dimensionality + natten_metadata.append(None) + continue + + # Use dim_params (parameters for this specific dimensionality) + window_size_list = [] + stride_list = [] + dilation_list = [] + + if "window_size_float" in dim_params: + window_size_float = dim_params["window_size_float"] + stride_float = dim_params["stride_float"] + dilation_float = dim_params["dilation_float"] + + for token_layout in token_layout_list: + window_size_ = tuple( + min(x, max(2, int(k * float(x)))) for k, x in zip(window_size_float, token_layout) + ) + stride_ = tuple(min(k, max(1, int(s * float(k)))) for s, k in zip(stride_float, window_size_)) + max_dilation = tuple(x // k for k, x in zip(window_size_, token_layout)) + dilation_ = tuple(min(m, max(1, int(d * float(m)))) for d, m in zip(dilation_float, max_dilation)) + + 
window_size_list.append(window_size_) + stride_list.append(stride_) + dilation_list.append(dilation_) + + assert len(window_size_list) == len(stride_list) == len(dilation_list) == len(token_layout_list) + + log.debug(f"Layer {i}: {window_size_list=}") + log.debug(f"Layer {i}: {stride_list=}") + log.debug(f"Layer {i}: {dilation_list=}") + + elif "window_size" in dim_params: + window_size = dim_params["window_size"] + stride = dim_params["stride"] + dilation = dim_params["dilation"] + + window_size_list = [window_size for _ in range(len(token_layout_list))] + stride_list = [stride for _ in range(len(token_layout_list))] + dilation_list = [dilation for _ in range(len(token_layout_list))] + else: + raise ValueError( + "Sparse parameters for a layer must have key 'window_size' or 'window_size_float', " + f"got {dim_params=} in layer index {i}." + ) + + is_causal = dim_params["is_causal"] + + # Create varlen metadata for natten varlen/varsized ops + # NOTE: generate_multi_dim_varlen_parameters will automatically map window size -1 to + # full size, that's why constant window sizes aren't allowed. + # NOTE: if any of the parameters are constant, natten will simplify them + natten_metadata.append( + generate_multi_dim_varlen_parameters( + token_layout_list=token_layout_list, + head_dim=head_dim, + device=device, + dtype=dtype, + requires_grad=requires_grad, + # + window_size_list=window_size_list, + stride_list=stride_list, + dilation_list=dilation_list, + # + is_causal=is_causal, + ) + ) + + return natten_metadata + + +def generate_temporal_causal_natten_metadata( + vision_token_shapes: list[tuple[int, int, int]], + num_action_tokens_per_supertoken: int, + num_layers: int, + head_dim: int, + device: torch.device, + dtype: torch.dtype, + requires_grad: bool, +) -> list: + """Generate per-layer varlen metadata for temporal causal attention on supertokens. 
+ + Each sample's generation tokens are laid out as T_i supertokens of size + S_i = num_action_tokens_per_supertoken + H_i*W_i. Metadata encodes + is_causal=(True, False): causal across T, full within S. All layers share + the same metadata (full window, no spatial sparsity). + + Unlike generate_natten_metadata, this function does not apply filter_shape — (T, S) layouts + are passed directly even when T=1. NATTEN handles T=1 causal masking correctly (trivially + full attention within S). + + Args: + vision_token_shapes: List of (T, H, W) per sample. + num_action_tokens_per_supertoken: Number of action tokens prefixing each + supertoken (0 when actions are not packed inline). + num_layers: Number of transformer layers. + head_dim: Attention head dimension. + device: Target device. + dtype: Target dtype. + requires_grad: Whether metadata tensors require gradient. + + Returns: + List of length num_layers, each element the same NATTEN varlen metadata dict. + """ + # T=1: NATTEN requires kernel_size >= 2 and kernel_size <= token_layout, which are mutually + # exclusive when T=1. Fall back to full dense attention (None) — a single supertoken trivially + # attends to only itself, so temporal causality is already satisfied. + # Mixed T=1/T>1 batches are rejected: NATTEN can't mask T=1 samples, and falling back to dense + # attention for the whole batch would break temporal causality for the T>1 samples. + # Ensure min_frames >= 5 in the dataloader so that T_latent = 1 + (N-1)//tcf >= 2 always. + has_short = any(t < 2 for t, h, w in vision_token_shapes) + if has_short: + if not all(t < 2 for t, h, w in vision_token_shapes): + raise ValueError( + "Mixed T=1 and T>1 samples in causal training batch: NATTEN cannot apply " + "causal masking when any sample has T=1 (kernel_size constraint), and falling " + "back to dense attention would break temporal causality for T>1 samples. " + "Ensure all samples have T_latent >= 2 (set min_frames >= 5 in the dataloader)." 
+ ) + return [None] * num_layers + token_layout_list = [(t, num_action_tokens_per_supertoken + h * w) for t, h, w in vision_token_shapes] + metadata = generate_multi_dim_varlen_parameters( + token_layout_list=token_layout_list, + head_dim=head_dim, + device=device, + dtype=dtype, + requires_grad=requires_grad, + is_causal=(True, False), + ) + return [metadata] * num_layers diff --git a/cosmos_framework/data/vfm/sequence_packing/packers.py b/cosmos_framework/data/vfm/sequence_packing/packers.py new file mode 100644 index 0000000..f311f87 --- /dev/null +++ b/cosmos_framework/data/vfm/sequence_packing/packers.py @@ -0,0 +1,472 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Top-level input sequence packing orchestration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from cosmos_framework.data.vfm.sequence_packing.modalities import ( + pack_action_tokens, + pack_sound_tokens, + pack_text_tokens, + pack_vision_tokens, +) +from cosmos_framework.data.vfm.sequence_packing.temporal_causal import pack_supertokens_temporal_causal +from cosmos_framework.data.vfm.sequence_packing.types import PackedSequence, SequencePlan + +if TYPE_CHECKING: + from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean + + +def pack_input_sequence( + sequence_plans: list[SequencePlan], + input_text_indexes: list[list[int]], + gen_data_clean: GenerationDataClean, + input_timesteps: torch.Tensor, + special_tokens: dict[str, int], + max_num_tokens: int | None = None, + latent_patch_size: int = 1, + skip_text_tokens: bool = False, + include_end_of_generation_token: bool = False, + position_embedding_type: str = "3d_rope", + unified_3d_mrope_reset_spatial_ids: bool = True, + unified_3d_mrope_temporal_modality_margin: int = 0, + enable_fps_modulation: bool = False, + base_fps: float = 24.0, + 
sound_base_temporal_compression_factor: int | None = None, + temporal_compression_factor: int = 4, + vision_temporal_position_mode: str = "latent_index", + video_temporal_causal: bool = False, + action_dim: int = 32, + initial_mrope_temporal_offset: int | float = 0, +) -> PackedSequence: + """ + Pack a sequence of input strings and VAE latents into a packed tensor format. + Uses SequencePlan to determine which modalities are present for each sample, + and maintains separate indices for text, vision, action, and sound to handle variable modality presence. + + Args: + sequence_plans: List of SequencePlan items describing which modalities are present. + input_text_indexes: List of text token ID sequences (only for samples where has_text=True). + gen_data_clean: GenerationDataClean containing vision, action, and sound tensors. + - x0_tokens_vision: Vision tensors for samples where has_vision=True + - x0_tokens_action: Action tensors for samples where has_action=True + - x0_tokens_sound: Sound tensors (list of [C, T]) for samples where has_sound=True + input_timesteps: Diffusion timesteps for each sample. Shape (B,) or (B, 1) for + teacher_forcing/none (all frames share the same sigma), or (B, T_max) for + diffusion_forcing (per-frame independent sigma). Entries are extracted per + sample as a float (numel==1) or Tensor(T_max,) for per-frame indexing. 
+ special_tokens: Dictionary containing special token IDs (eos_token_id, start_of_generation, end_of_generation) + max_num_tokens: Maximum number of tokens in the packed sequence + latent_patch_size: Patch size used by the network to pack latents + skip_text_tokens: If True, skip packing text tokens + include_end_of_generation_token: If True, append end-of-generation token + position_embedding_type: Position embedding type for vision tokens: + - "3d_rope": Additive 3D RoPE embeddings + 1D position IDs for attention + - "flattened_sin_cos": Additive flattened sin/cos embeddings + 1D position IDs + - "unified_3d_mrope": No additive embedding + 3D position IDs for Qwen3VL-style mRoPE + unified_3d_mrope_reset_spatial_ids: If True (default), spatial (H, W) indices + start from 0 for each vision segment. If False, spatial indices are offset + by the temporal offset (Qwen2VL-style). Only used when position_embedding_type="unified_3d_mrope". + enable_fps_modulation: If True, scale temporal position IDs based on video FPS + to reflect real time. Requires fps_vision in gen_data_clean. + Uses the same flag as diffusion_expert_config.enable_fps_modulation. + base_fps: Base FPS for normalization (default 24.0). + Uses the same value as diffusion_expert_config.base_fps. + sound_base_temporal_compression_factor: Base temporal compression factor for sound FPS scaling. + ``None`` preserves the current behavior where sound advances at ``base_fps`` positions/sec. + temporal_compression_factor: VAE temporal compression factor (default 4). + Obtained from the VAE tokenizer at runtime. + vision_temporal_position_mode: Temporal coordinates used for unified_3d_mrope vision tokens. + "latent_index" keeps legacy positions; "uniae_source_right_edge" uses + per-latent positions from gen_data_clean.temporal_positions_vision. + Returns: + PackedSequence containing all packed tensors and metadata. See PackedSequence for field details. 
+ """ + del max_num_tokens + + assert special_tokens is not None, "Special tokens must be provided" + assert isinstance(input_timesteps, torch.Tensor), "input_timesteps must be a tensor" + if input_timesteps.is_cuda: + raise ValueError("input_timesteps must be on CPU, not CUDA") + if isinstance(input_text_indexes, torch.Tensor): + raise ValueError("input_text_tokens must be a list, not a tensor") + + supported_vision_temporal_position_modes = {"latent_index", "uniae_source_right_edge"} + if vision_temporal_position_mode not in supported_vision_temporal_position_modes: + raise ValueError( + "Unsupported vision_temporal_position_mode: " + f"{vision_temporal_position_mode}. Supported modes: {supported_vision_temporal_position_modes}." + ) + has_any_vision = any(plan.has_vision for plan in sequence_plans) + explicit_vision_temporal_positions_active = vision_temporal_position_mode != "latent_index" and has_any_vision + if explicit_vision_temporal_positions_active: + if position_embedding_type != "unified_3d_mrope": + raise NotImplementedError( + "Explicit vision temporal positions are only supported with position_embedding_type='unified_3d_mrope'." + ) + if gen_data_clean.temporal_positions_vision is None: + raise ValueError( + f"vision_temporal_position_mode={vision_temporal_position_mode} requires " + "gen_data_clean.temporal_positions_vision." + ) + if gen_data_clean.x0_tokens_vision is not None and len(gen_data_clean.temporal_positions_vision) != len( + gen_data_clean.x0_tokens_vision + ): + raise ValueError( + "temporal_positions_vision must have one entry per x0_tokens_vision item, " + f"got {len(gen_data_clean.temporal_positions_vision)} positions for " + f"{len(gen_data_clean.x0_tokens_vision)} vision items." + ) + if video_temporal_causal: + raise NotImplementedError( + "video_temporal_causal=True is not wired for explicit UniAE vision temporal positions yet." 
+ ) + if any(plan.has_action for plan in sequence_plans): + raise NotImplementedError("Action packing is not wired for explicit UniAE vision temporal positions yet.") + if initial_mrope_temporal_offset != 0: + raise NotImplementedError( + "Autoregressive mRoPE temporal offsets are not wired for explicit UniAE vision temporal positions yet." + ) + use_float_mrope_positions = enable_fps_modulation or explicit_vision_temporal_positions_active + + # Initialize packed sequence (acts as builder during packing) + packed_seq = PackedSequence() + + # Configure 3D mRoPE on the builder (enabled when position_embedding_type is unified_3d_mrope) + packed_seq._use_mrope = position_embedding_type == "unified_3d_mrope" + packed_seq._mrope_reset_spatial = unified_3d_mrope_reset_spatial_ids + + # Maintain separate indices for each modality + idx_text = 0 + idx_vision = 0 + idx_action = 0 + idx_sound = 0 + null_action_flags: list[bool] = [] # collected from TC path; asserted consistent after the loop + + # Validate: all samples must have text (causal split is always required for two-way attention). + # CFG dropout only drops text *content*, not the structural text split. + if not skip_text_tokens: + for plan in sequence_plans: + assert plan.has_text, "All sequence plans must have has_text=True when skip_text_tokens=False" + + # Pack each sample based on its sequence plan + for sample_idx, sequence_plan in enumerate(sequence_plans): + curr_rope_id = 0 + sample_len = 0 + + # mRoPE temporal offset resets per sample. + # initial_mrope_temporal_offset is non-zero only for AR inference (frame N seeds at N*tcf). 
+ packed_seq._mrope_temporal_offset = initial_mrope_temporal_offset + + _ts = input_timesteps[sample_idx] + input_timestep = _ts.item() if _ts.numel() == 1 else _ts # float (TF) or Tensor(T_max,) (DF) + + # Pack text tokens if has_text=True and not skipped + if sequence_plan.has_text and not skip_text_tokens: + text_ids = input_text_indexes[idx_text] + idx_text += 1 + + has_generation_for_sample = sequence_plan.has_vision or sequence_plan.has_action or sequence_plan.has_sound + curr_rope_id, _, text_sample_len = pack_text_tokens( + packed_seq, + text_ids, + special_tokens, + curr_rope_id, + has_generation=has_generation_for_sample, + use_float_positions=use_float_mrope_positions, + ) + sample_len += text_sample_len + + # End of text modality, add an offset as the boundary between text and vision. + packed_seq._mrope_temporal_offset += unified_3d_mrope_temporal_modality_margin + + # Save temporal offset before vision for action tokens (action uses same offset as vision start) + vision_start_temporal_offset = packed_seq._mrope_temporal_offset + + # Pack vision (and optionally action) tokens + if video_temporal_causal and sequence_plan.has_vision: + # Temporal causal path: when sequence_plan.has_action=True, interleaved supertokens + # [action_t, vision_t]; when False, supertokens are just vision patches. 
+ assert position_embedding_type == "unified_3d_mrope", ( + "video_temporal_causal=True requires position_embedding_type='unified_3d_mrope'" + ) + input_vision_tokens = gen_data_clean.x0_tokens_vision[idx_vision] + idx_vision += 1 + + vision_fps = None + if ( + enable_fps_modulation + and gen_data_clean.fps_vision is not None + and idx_vision - 1 < len(gen_data_clean.fps_vision) + ): + vision_fps = float(gen_data_clean.fps_vision[idx_vision - 1].item()) + + input_action_tokens_tc: torch.Tensor | None = None + action_fps_tc: float | None = None + if sequence_plan.has_action: + input_action_tokens_tc = gen_data_clean.x0_tokens_action[idx_action] + if ( + enable_fps_modulation + and gen_data_clean.fps_action is not None + and idx_action < len(gen_data_clean.fps_action) + ): + action_fps_tc = float(gen_data_clean.fps_action[idx_action].item()) + idx_action += 1 + + supertoken_split_len, null_flag = pack_supertokens_temporal_causal( + packed_seq=packed_seq, + input_vision_tokens=input_vision_tokens, + input_action_tokens=input_action_tokens_tc, + condition_frame_indexes_vision=sequence_plan.condition_frame_indexes_vision, + input_timestep=input_timestep, + curr_rope_id=curr_rope_id, + latent_patch_size=latent_patch_size, + temporal_compression_factor=temporal_compression_factor, + action_dim=action_dim, + vision_fps=vision_fps, + action_fps=action_fps_tc, + enable_fps_modulation=enable_fps_modulation, + base_fps=base_fps, + pack_action_tokens=sequence_plan.has_action, + ) + null_action_flags.append(null_flag) + # We assume all samples in a batch share the same has_action layout, so + # stamp the supertoken layout constant directly here. This is the + # single source of truth read by downstream attention / KV-cache + # code (no recomputation in the network). 
+ packed_seq.num_action_tokens_per_supertoken = temporal_compression_factor if sequence_plan.has_action else 0 + sample_len += supertoken_split_len + vision_split_len = supertoken_split_len + action_split_len = 0 # Already absorbed into supertoken_split_len + + else: + # Standard path: vision and action packed separately + if sequence_plan.has_vision: + # Determine how many vision items this sample owns. + # For multi-item samples (e.g. image editing), num_vision_items_per_sample + # records [2, 2, ...]; for standard T2I/T2V it is None (1 item per sample). + num_vis = ( + gen_data_clean.num_vision_items_per_sample[sample_idx] + if gen_data_clean.num_vision_items_per_sample is not None + else 1 + ) + + vision_split_len = 0 + # Controlnet-style transfer: when set, all vision items share the same + # temporal mRoPE grid. We snapshot the offset before the loop and + # rewind to it before each item, so every item produces identical + # temporal IDs. Each pack_vision_tokens call still advances the + # offset by latent_t internally; in shared-grid mode the post-loop + # offset equals snapshot + latent_t (single-clip semantics for + # downstream EOV / next-modality tokens). + shared_grid = sequence_plan.share_vision_temporal_positions and num_vis > 1 + items_temporal_offset_snapshot = packed_seq._mrope_temporal_offset + shared_latent_t: int | None = None + shared_patch_h: int | None = None + shared_patch_w: int | None = None + shared_temporal_positions: torch.Tensor | None = None + # FPS is recorded per-sample (shape [B]); for multi-item samples + # (transfer / image-edit) every vision item in this sample shares + # the same conditioning FPS, so we read by sample_idx, not by the + # flat idx_vision counter (which would alias to a neighbor sample's + # fps and corrupt RoPE FPS modulation). 
+ sample_vision_fps: float | None = None + if ( + enable_fps_modulation + and gen_data_clean.fps_vision is not None + and sample_idx < len(gen_data_clean.fps_vision) + ): + sample_vision_fps = float(gen_data_clean.fps_vision[sample_idx].item()) + + for item_idx in range(num_vis): + flat_vision_idx = idx_vision + input_vision_tokens = gen_data_clean.x0_tokens_vision[flat_vision_idx] + vision_temporal_positions: torch.Tensor | None = None + if explicit_vision_temporal_positions_active: + assert gen_data_clean.temporal_positions_vision is not None + vision_temporal_positions = gen_data_clean.temporal_positions_vision[flat_vision_idx] + if vision_temporal_positions.shape[0] != input_vision_tokens.shape[2]: + raise ValueError( + "vision_temporal_positions must match latent_t for each vision item, " + f"got {vision_temporal_positions.shape[0]} positions and " + f"latent_t={input_vision_tokens.shape[2]} for item {flat_vision_idx}." + ) + vision_fps = sample_vision_fps + idx_vision += 1 + + # Determine conditioning for this vision item. + # For multi-item mode: all items except the last are fully conditioned + # (all frames are clean); the last item uses the SequencePlan's + # condition_frame_indexes_vision (typically [] = fully generated). + if num_vis > 1 and item_idx < num_vis - 1: + # Conditioning item (e.g. 
source image): mark all frames as clean + latent_t = input_vision_tokens.shape[2] + item_condition_frames = list(range(latent_t)) + else: + # Generation item (single-item mode or last item in multi-item) + item_condition_frames = sequence_plan.condition_frame_indexes_vision + + if shared_grid: + item_latent_t = input_vision_tokens.shape[2] + item_latent_h = input_vision_tokens.shape[3] + item_latent_w = input_vision_tokens.shape[4] + if shared_latent_t is None: + shared_latent_t = item_latent_t + shared_patch_h = item_latent_h + shared_patch_w = item_latent_w + else: + assert item_latent_t == shared_latent_t, ( + f"share_vision_temporal_positions requires equal latent_t across items, " + f"got item {item_idx} latent_t={item_latent_t} vs first={shared_latent_t}" + ) + assert item_latent_h == shared_patch_h and item_latent_w == shared_patch_w, ( + f"share_vision_temporal_positions requires equal spatial grid across items, " + f"got item {item_idx} (H,W)=({item_latent_h},{item_latent_w}) " + f"vs first=({shared_patch_h},{shared_patch_w})" + ) + if vision_temporal_positions is not None: + if shared_temporal_positions is None: + shared_temporal_positions = vision_temporal_positions + else: + comparison_temporal_positions = vision_temporal_positions.to( + device=shared_temporal_positions.device + ) # [T] + assert torch.allclose(comparison_temporal_positions, shared_temporal_positions), ( + "share_vision_temporal_positions requires equal explicit temporal positions " + f"across vision items, got item {item_idx} positions " + f"{vision_temporal_positions.tolist()} vs first " + f"{shared_temporal_positions.tolist()}." + ) + # Rewind so this item starts at the same temporal offset as item 0. 
+ packed_seq._mrope_temporal_offset = items_temporal_offset_snapshot + + item_split_len = pack_vision_tokens( + packed_seq=packed_seq, + input_vision_tokens=input_vision_tokens, + condition_frame_indexes_vision=item_condition_frames, + input_timestep=input_timestep, + curr_rope_id=curr_rope_id, + latent_patch_size=latent_patch_size, + vision_fps=vision_fps, + enable_fps_modulation=enable_fps_modulation, + base_fps=base_fps, + temporal_compression_factor=temporal_compression_factor, + vision_temporal_positions=vision_temporal_positions, + ) + vision_split_len += item_split_len + sample_len += vision_split_len + + else: + vision_split_len = 0 + + # Pack action tokens if has_action=True + if sequence_plan.has_action: + input_action_tokens = gen_data_clean.x0_tokens_action[idx_action] + + # Get FPS for action (action may have its own FPS independent of vision) + action_fps: float | None = None + if ( + enable_fps_modulation + and gen_data_clean.fps_action is not None + and idx_action < len(gen_data_clean.fps_action) + ): + action_fps = float(gen_data_clean.fps_action[idx_action].item()) + + idx_action += 1 + + action_split_len = pack_action_tokens( + packed_seq=packed_seq, + input_action_tokens=input_action_tokens, + condition_frame_indexes_action=sequence_plan.condition_frame_indexes_action, + input_timestep=input_timestep, + curr_rope_id=curr_rope_id, + action_temporal_offset=vision_start_temporal_offset, + enable_fps_modulation=enable_fps_modulation, + base_fps=base_fps, + action_fps=action_fps, + base_temporal_compression_factor=temporal_compression_factor, + action_start_frame_offset=sequence_plan.action_start_frame_offset, + ) + sample_len += action_split_len + else: + action_split_len = 0 + + # Pack sound tokens if has_sound=True + if sequence_plan.has_sound: + input_sound_tokens = gen_data_clean.x0_tokens_sound[idx_sound] + + # Get FPS for sound (from gen_data_clean, like vision and action) + sound_fps: float | None = None + if ( + enable_fps_modulation + and 
gen_data_clean.fps_sound is not None + and idx_sound < len(gen_data_clean.fps_sound) + ): + sound_fps = float(gen_data_clean.fps_sound[idx_sound].item()) + + idx_sound += 1 + + sound_split_len = pack_sound_tokens( + packed_seq=packed_seq, + input_sound_tokens=input_sound_tokens, + condition_frame_indexes_sound=sequence_plan.condition_frame_indexes_sound, + input_timestep=input_timestep, + curr_rope_id=curr_rope_id, + sound_temporal_offset=vision_start_temporal_offset, + enable_fps_modulation=enable_fps_modulation, + base_fps=base_fps, + sound_fps=sound_fps, + sound_base_temporal_compression_factor=sound_base_temporal_compression_factor, + ) + sample_len += sound_split_len + else: + sound_split_len = 0 + + # Add end-of-generation token if needed + eov_len = 0 + has_any_generation = sequence_plan.has_vision or sequence_plan.has_action or sequence_plan.has_sound + if include_end_of_generation_token and has_any_generation: + # Type narrowing: we're in build mode, fields are lists + assert isinstance(packed_seq.text_ids, list) + assert isinstance(packed_seq.text_indexes, list) + assert isinstance(packed_seq.position_ids, list) + + packed_seq.text_ids.append(special_tokens["end_of_generation"]) + packed_seq.text_indexes.append(packed_seq.curr) + + # EOV position IDs: 3D mRoPE or 1D RoPE + if packed_seq._use_mrope: + # Use float dtype when any vision mRoPE positions are fractional. 
+ eov_dtype = torch.float32 if use_float_mrope_positions else torch.long + eov_mrope_ids = torch.full((3, 1), packed_seq._mrope_temporal_offset, dtype=eov_dtype) # [3,1] + packed_seq.position_ids.append(eov_mrope_ids) # type: ignore[arg-type] + packed_seq._mrope_temporal_offset += 1 + else: + packed_seq.position_ids.append(curr_rope_id) # type: ignore[arg-type] + + packed_seq.curr += 1 + eov_len = 1 + sample_len += 1 + + combined_split_len = vision_split_len + action_split_len + sound_split_len + eov_len + packed_seq.attn_modes.append("full") + packed_seq.split_lens.append(combined_split_len) + packed_seq.sample_lens.append(sample_len) + + # Assert consistent null_action_supertokens across all TC samples, then set once + if null_action_flags: + assert len(set(null_action_flags)) == 1, ( + f"Inconsistent null_action_supertokens across samples: {null_action_flags}. " + "All samples in a batch must have the same structure (all training or all AR inference)." + ) + packed_seq.null_action_supertokens = null_action_flags[0] + + # Finalize and return packed data + return packed_seq.finalize( + gen_data_clean=gen_data_clean, + ) diff --git a/cosmos_framework/data/vfm/sequence_packing/runtime.py b/cosmos_framework/data/vfm/sequence_packing/runtime.py new file mode 100644 index 0000000..2f80607 --- /dev/null +++ b/cosmos_framework/data/vfm/sequence_packing/runtime.py @@ -0,0 +1,524 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# Largest padded causal / full-only lengths seen so far, tracked separately for
# image batches and video batches.
MAX_CAUSAL_LEN_IMAGE_BATCH = 0
MAX_FULL_LEN_IMAGE_BATCH = 0
MAX_CAUSAL_LEN_VIDEO_BATCH = 0
MAX_FULL_LEN_VIDEO_BATCH = 0


def get_padding_stats() -> dict[str, int]:
    """Snapshot the module-level sequence-packing padding bounds.

    Returns:
        A new dict mapping each ``MAX_*_LEN_*_BATCH`` global's name to the value it
        holds at call time.
    """
    stat_names = (
        "MAX_CAUSAL_LEN_IMAGE_BATCH",
        "MAX_FULL_LEN_IMAGE_BATCH",
        "MAX_CAUSAL_LEN_VIDEO_BATCH",
        "MAX_FULL_LEN_VIDEO_BATCH",
    )
    module_scope = globals()
    return {stat_name: module_scope[stat_name] for stat_name in stat_names}
+ out = [] + full_offset = 0 + packed_idx = 0 + und_token_set = set(und_token_indexes) + for attn_mode, split_len in zip(attn_modes, split_lens): + if attn_mode == "full": + split_indices = range(packed_idx, packed_idx + split_len) + # For this "full" split, find the und tokens within this split, mapped local to full_only_seq offset + for local_idx, split_idx in enumerate(split_indices): + if split_idx in und_token_set: + out.append(full_offset + local_idx) + full_offset += split_len + packed_idx += split_len + return out + + +def _compute_mode_indices_and_offsets( + split_lens: torch.Tensor | List[int], attn_modes: List[str], mode: str, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute indices from a joint tensor that are in the given mode. + """ + indices = [] + offsets = [0] + next_offset = 0 + start = 0 + + if isinstance(split_lens, torch.Tensor): + split_lens = split_lens.tolist() + + for i, (split_len, attn_mode) in enumerate(zip(split_lens, attn_modes)): + if attn_mode == mode: + indices.extend(range(start, start + split_len)) + next_offset += split_len + offsets.append(next_offset) + start += split_len + return torch.tensor(indices, dtype=torch.int32, device=device), torch.tensor( # [N_mode_tokens], [N_mode_splits+1] + offsets, dtype=torch.int32, device=device + ) + + +# Pad causal_seq and full_only_seq to have length 2048 if not already at that size +def _pad_to_N(N, x: torch.Tensor) -> torch.Tensor: + assert x.shape[0] <= N + padded = x.new_zeros((N, *x.shape[1:])) + padded[: x.shape[0]] = x + return padded + + +def _round_up_to_N(n: int, cp_world_size: int = 1, pad_for_cuda_graphs: bool = False) -> int: + if pad_for_cuda_graphs: + # Reduce recompilations / CUDA graph re-captures by bucketing lengths. 
+ # <= 2K: 128, <= 4K: 256, <= 8K: 512, <= 16K: 1024, > 16K: 2048 + if n <= 2048: + alignment = 128 + elif n <= 4096: + alignment = 256 + elif n <= 8192: + alignment = 512 + elif n <= 16384: + alignment = 1024 + else: + alignment = 2048 + n = ((n + alignment - 1) // alignment) * alignment + + # ensure it's divisible by cp_world_size + if cp_world_size > 1: + remainder = n % cp_world_size + if remainder != 0: + n += cp_world_size - remainder + + return n + + +def _pad( + causal_seq: torch.Tensor, full_only_seq: torch.Tensor, max_causal_len: int, max_full_len: int +) -> tuple[torch.Tensor, torch.Tensor]: + causal_seq = _pad_to_N(max_causal_len, causal_seq) + full_only_seq = _pad_to_N(max_full_len, full_only_seq) + return causal_seq, full_only_seq + + +def _ensure_core_metadata(pack: SequencePack) -> None: + required = [ + "sample_offsets", + "max_sample_len", + "max_causal_len", + "max_full_len", + "_causal_indices", + "_full_indices", + "_causal_seq_offsets", + "_full_only_seq_offsets", + "is_sharded", + ] + for key in required: + if key not in pack: + raise KeyError(f"Missing required pack field: {key}") + + +def init_sequence_pack( + sample_lens: List[int], + split_lens: List[int], + attn_modes: List[str], + device: torch.device, +) -> dict[str, Any]: + _max_sample_len = max(sample_lens) + _max_causal_len = max((split_lens[i] for i in range(len(split_lens)) if attn_modes[i] == "causal"), default=0) + _max_full_len = max((split_lens[i] for i in range(len(split_lens)) if attn_modes[i] == "full"), default=0) + + sample_lens_cu = torch.tensor([0] + sample_lens, device=device, dtype=torch.int32) # [N_samples+1] + _sample_offsets = torch.cumsum(sample_lens_cu, dim=0, dtype=torch.int32) # [N_samples+1] + + _causal_indices, _causal_seq_offsets = _compute_mode_indices_and_offsets(split_lens, attn_modes, "causal", device) + _full_indices, _full_only_seq_offsets = _compute_mode_indices_and_offsets(split_lens, attn_modes, "full", device) + + return dict( + 
def _round_up_for_cuda_graphs_or_cp(
    causal_seq: torch.Tensor,
    full_only_seq: torch.Tensor,
    need_causal: int,
    need_full: int,
    is_image_batch: bool,
    pad_for_cuda_graphs: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Zero-pad the causal and full-only sequences to their target lengths.

    Without CUDA-graph padding the sequences are padded to exactly
    ``need_causal`` / ``need_full``, or returned untouched when both already have
    those lengths. With CUDA-graph padding the module-level maxima for the batch
    kind (image or video) are first raised to the requested lengths if those are
    larger, each growth is logged, and the sequences are padded to the maxima.

    Args:
        causal_seq: Causal tokens; padded along dim 0.
        full_only_seq: Full-attention tokens; padded along dim 0.
        need_causal: Required causal length for this batch.
        need_full: Required full-only length for this batch.
        is_image_batch: Selects the image maxima when True, the video maxima otherwise.
        pad_for_cuda_graphs: Whether to pad to the running maxima.

    Returns:
        The ``(causal_seq, full_only_seq)`` pair after padding.
    """
    global MAX_CAUSAL_LEN_IMAGE_BATCH, MAX_FULL_LEN_IMAGE_BATCH
    global MAX_CAUSAL_LEN_VIDEO_BATCH, MAX_FULL_LEN_VIDEO_BATCH

    if not pad_for_cuda_graphs:
        already_sized = need_causal == int(causal_seq.shape[0]) and need_full == int(full_only_seq.shape[0])
        if already_sized:
            return causal_seq, full_only_seq
        return _pad(causal_seq, full_only_seq, need_causal, need_full)

    if is_image_batch:
        if MAX_CAUSAL_LEN_IMAGE_BATCH < need_causal:
            MAX_CAUSAL_LEN_IMAGE_BATCH = need_causal
            log.info(f"Growing MAX_CAUSAL_LEN_IMAGE_BATCH to {MAX_CAUSAL_LEN_IMAGE_BATCH}", rank0_only=False)
        if MAX_FULL_LEN_IMAGE_BATCH < need_full:
            MAX_FULL_LEN_IMAGE_BATCH = need_full
            log.info(f"Growing MAX_FULL_LEN_IMAGE_BATCH to {MAX_FULL_LEN_IMAGE_BATCH}", rank0_only=False)
        return _pad(
            causal_seq,
            full_only_seq,
            max_causal_len=MAX_CAUSAL_LEN_IMAGE_BATCH,
            max_full_len=MAX_FULL_LEN_IMAGE_BATCH,
        )

    if MAX_CAUSAL_LEN_VIDEO_BATCH < need_causal:
        MAX_CAUSAL_LEN_VIDEO_BATCH = need_causal
        log.info(f"Growing MAX_CAUSAL_LEN_VIDEO_BATCH to {MAX_CAUSAL_LEN_VIDEO_BATCH}", rank0_only=False)
    if MAX_FULL_LEN_VIDEO_BATCH < need_full:
        MAX_FULL_LEN_VIDEO_BATCH = need_full
        log.info(f"Growing MAX_FULL_LEN_VIDEO_BATCH to {MAX_FULL_LEN_VIDEO_BATCH}", rank0_only=False)
    return _pad(
        causal_seq,
        full_only_seq,
        max_causal_len=MAX_CAUSAL_LEN_VIDEO_BATCH,
        max_full_len=MAX_FULL_LEN_VIDEO_BATCH,
    )
def zeros_like(orig: SequencePack, shape: Tuple[int, ...] | torch.Size | None = None) -> SequencePack:
    """Build an all-zero sequence pack that reuses the metadata of ``orig``.

    Args:
        orig: Pack supplying the metadata, plus the device, dtype and leading (token)
            dimension of each split.
        shape: Optional shape override. Its first entry must be ``-1``; the remaining
            entries replace the trailing dimensions of both splits. When None, each
            split keeps the shape it has in ``orig``.

    Returns:
        A new pack whose ``causal_seq`` and ``full_only_seq`` are zero tensors.
    """
    _ensure_core_metadata(orig)
    if shape is not None:
        assert len(shape) >= 1 and shape[0] == -1

    causal_ref = orig["causal_seq"]
    full_ref = orig["full_only_seq"]
    if shape is None:
        causal_shape, full_shape = causal_ref.shape, full_ref.shape
    else:
        # Keep each split's own token count; only the trailing dims come from `shape`.
        trailing_dims = tuple(shape)[1:]
        causal_shape = (causal_ref.shape[0], *trailing_dims)
        full_shape = (full_ref.shape[0], *trailing_dims)

    zero_causal = torch.zeros(causal_shape, device=causal_ref.device, dtype=causal_ref.dtype)  # [N_causal_tokens,D]
    zero_full = torch.zeros(full_shape, device=full_ref.device, dtype=full_ref.dtype)  # [N_full_tokens,D]
    return from_mode_splits(zero_causal, zero_full, orig)
def from_mode_splits(
    causal_seq: torch.Tensor,
    full_only_seq: torch.Tensor,
    orig: SequencePack,
    is_sharded: bool | None = None,
) -> SequencePack:
    """Copy the metadata of ``orig`` into a new pack holding the given mode splits.

    Args:
        causal_seq: Tensor stored under ``"causal_seq"``.
        full_only_seq: Tensor stored under ``"full_only_seq"``.
        orig: Pack whose entries are shallow-copied; it is not modified.
        is_sharded: Value stored under ``"is_sharded"``. None inherits it from ``orig``.

    Returns:
        A new dict with the entries of ``orig`` and the three fields above replaced.

    Raises:
        KeyError: If ``orig`` lacks a required core metadata field.
    """
    _ensure_core_metadata(orig)
    sharded = orig.get("is_sharded", False) if is_sharded is None else is_sharded
    return {
        **orig,
        "causal_seq": causal_seq,
        "full_only_seq": full_only_seq,
        "is_sharded": sharded,
    }
def get_all_seq(pack: SequencePack) -> torch.Tensor:
    """Reassemble all tokens of a sequence pack into a single tensor.

    If ``pack`` has an ``"all_seq"`` entry, that value is returned as is. Otherwise a
    zero tensor is allocated from ``pack["causal_seq"]`` (same dtype, device and
    trailing dims) with ``len(_causal_indices) + len(_full_indices)`` rows, and each
    split is written back to its index positions. Rows of a split beyond its index
    count (padding) are dropped.

    Args:
        pack: The sequence pack to read from.

    Returns:
        All tokens concatenated over all sequences in the batch.

    Raises:
        KeyError: If a required core metadata field is missing.
        AssertionError: If the pack is sharded for context parallelism.
    """
    if "all_seq" in pack:
        return pack["all_seq"]
    _ensure_core_metadata(pack)
    if pack["is_sharded"]:
        # Raised explicitly instead of `assert False`: asserts are stripped under
        # `python -O`, which would let a sharded pack fall through to the scatter below.
        raise AssertionError("get_all_seq is not supported in context parallel sharded mode")

    causal_seq = pack["causal_seq"]
    causal_indices = pack["_causal_indices"]
    full_indices = pack["_full_indices"]
    num_causal = causal_indices.shape[0]
    num_full = full_indices.shape[0]

    out = causal_seq.new_zeros(int(num_causal + num_full), *causal_seq.shape[1:])  # [seq_len,D]
    if causal_seq.shape[0] > 0:
        out[causal_indices] = causal_seq[:num_causal]
    full_only_seq = pack["full_only_seq"]
    if full_only_seq.shape[0] > 0:
        out[full_indices] = full_only_seq[:num_full]
    return out
def get_und_position_ids(position_ids: torch.Tensor, meta: dict[str, Any]) -> torch.Tensor:
    """Select the position ids of the understanding (causal) tokens.

    Args:
        position_ids: Position ids, shape ``(seq_len,)`` for 1D RoPE or
            ``(3, seq_len)`` for 3D mRoPE.
        meta: Pack metadata; ``"is_sharded"`` and ``"_causal_indices"`` are read.

    Returns:
        ``position_ids`` gathered at ``meta["_causal_indices"]`` along the sequence
        dimension (the last dimension for the 2-D mRoPE layout).

    Raises:
        AssertionError: If ``meta["is_sharded"]`` is truthy.
    """
    assert not meta["is_sharded"], "get_und_position_ids is not supported in context parallel sharded mode"
    if position_ids.dim() == 2:
        # 3D mRoPE: keep all three rows, gather along the sequence axis.
        selector = (slice(None), meta["_causal_indices"])  # -> [3,N_causal_tokens]
    else:
        selector = meta["_causal_indices"]  # -> [N_causal_tokens]
    return position_ids[selector]
def pack_supertokens_temporal_causal(
    packed_seq: "PackedSequence",
    input_vision_tokens: torch.Tensor,
    input_action_tokens: torch.Tensor | None,
    condition_frame_indexes_vision: list[int],
    input_timestep: float | torch.Tensor,
    curr_rope_id: int,
    latent_patch_size: int,
    temporal_compression_factor: int,
    action_dim: int,
    vision_fps: float | None = None,
    action_fps: float | None = None,
    enable_fps_modulation: bool = False,
    base_fps: float = 24.0,
    pack_action_tokens: bool = True,
) -> tuple[int, bool]:
    """Pack vision and (optionally) action tokens in supertoken order for temporal causal attention.

    Buffer layout per frame:
      pack_action_tokens=True:  [action_t (tcf), vision_t (H*W)]  — supertoken size tcf + H*W
      pack_action_tokens=False: [vision_t (H*W)]                  — supertoken size H*W

    Use ``pack_action_tokens=False`` when ``config.action_gen=False``; the resulting
    ``num_action_tokens_per_supertoken=0`` is stamped on the pack and read by the
    attention builder so NATTEN metadata stays in sync automatically.

    mRoPE layout (with actions, unified_3d_mrope only). The layout is inferred from the
    action tensor shape:
    - Whole-clip training (frame 0 is the clean conditioning frame, so
      ``real_actions`` has ``(T-1)*tcf`` rows): null action for supertoken 0, real
      actions for frames 1..T-1 with ``start_frame_offset=1`` so the last action in
      group i co-locates with vision frame i; vision uses ``start_frame_offset=0``.
    - AR generation, single frame OR chunk (every frame carries a real action, so
      ``real_actions`` has ``latent_t*tcf`` rows): vision AND action both use
      ``start_frame_offset=1``, generalizing the single-frame AR supertoken to
      ``latent_t`` frames. The caller (``pack_input_sequence_autoregressive``)
      seeds ``temporal_offset`` one frame-stride back to compensate, so the unit
      lands at the same absolute positions as the whole-clip training pack.
    - Interleaved per frame as cat([action_ids, vision_ids]).

    ``input_timestep`` is float (TF/none) or Tensor(T_max,) (DF, per-frame sigma).
    Conditioning frames are excluded from mse_loss_indexes either way.

    Args:
        packed_seq: Builder being filled; must still hold Python lists (build mode).
            Mutated in place: ``vision`` / ``action`` modality data, ``position_ids``,
            ``curr`` and, under mRoPE, ``_mrope_temporal_offset``.
        input_vision_tokens: 5-D latent tensor; its last three dims are read as
            ``(latent_t, latent_h, latent_w)``. Stored as is in ``vision.tokens``.
        input_action_tokens: Real action tokens, 2-D or 3-D (a 3-D tensor has dim 0
            squeezed), or None for an all-null action stream. Must be None when
            ``pack_action_tokens`` is False.
        condition_frame_indexes_vision: Frame indexes treated as clean conditioning
            frames; indexes outside ``[0, latent_t)`` are ignored.
        input_timestep: A single timestep, or a tensor indexed by frame.
        curr_rope_id: Position id written for every token when mRoPE is disabled.
        latent_patch_size: Patch size; the patch grid is ``ceil(latent / patch)`` per axis.
        temporal_compression_factor: Number of action tokens per frame (``tcf``).
        action_dim: Feature width used when creating null action tokens.
        vision_fps: Vision FPS, forwarded to the mRoPE helper only when
            ``enable_fps_modulation`` is True.
        action_fps: Action FPS, forwarded to the mRoPE helper only when
            ``enable_fps_modulation`` is True.
        enable_fps_modulation: Whether the FPS values above are used at all.
        base_fps: Base FPS forwarded to the mRoPE helper.
        pack_action_tokens: Whether action tokens are part of each supertoken.

    Returns:
        ``(total_split_len, null_action_flag)``. ``total_split_len`` is the number of
        sequence positions consumed. ``null_action_flag`` is True when supertoken 0
        (or every supertoken, if no actions were given) holds null action tokens, and
        False when every frame has a real action or ``pack_action_tokens`` is False.

    Raises:
        ValueError: If the action tensor has neither ``latent_t*tcf`` nor
            ``(latent_t-1)*tcf`` rows.
    """
    assert isinstance(packed_seq.position_ids, list), "PackedSequence must be in build mode"

    _, _, latent_t, latent_h, latent_w = input_vision_tokens.shape
    patch_h = math.ceil(latent_h / latent_patch_size)
    patch_w = math.ceil(latent_w / latent_patch_size)
    tcf = temporal_compression_factor
    patches_per_frame = patch_h * patch_w
    supertoken_len = tcf + patches_per_frame if pack_action_tokens else patches_per_frame  # S

    # Initialize modalities if needed
    if packed_seq.vision is None:
        packed_seq.vision = ModalityData()
    if pack_action_tokens and packed_seq.action is None:
        packed_seq.action = ModalityData()

    # Type narrowing: in build mode these fields are lists, not finalized tensors.
    assert isinstance(packed_seq.vision.sequence_indexes, list)
    assert isinstance(packed_seq.vision.mse_loss_indexes, list)
    assert isinstance(packed_seq.vision.timesteps, list)
    assert isinstance(packed_seq.vision.tokens, list)
    assert isinstance(packed_seq.vision.condition_mask, list)
    if pack_action_tokens:
        assert isinstance(packed_seq.action.sequence_indexes, list)
        assert isinstance(packed_seq.action.mse_loss_indexes, list)
        assert isinstance(packed_seq.action.timesteps, list)
        assert isinstance(packed_seq.action.tokens, list)
        assert isinstance(packed_seq.action.condition_mask, list)

    device = input_vision_tokens.device
    dtype = input_vision_tokens.dtype

    null_action_flag: bool
    if pack_action_tokens:
        # Build all_action_tokens: shape (latent_t * tcf, action_dim)
        #
        # Cases (token assembly; mRoPE start_frame_offset is chosen separately below,
        # inferred from the same action shape):
        # 1. Whole-clip training with conditioning frame (latent_t > 1, real_actions
        #    has (T-1)*tcf rows): prepend tcf null tokens for frame 0, then real
        #    actions for frames 1..T-1.
        # 2. AR generation (every frame has a real action, real_actions has
        #    latent_t*tcf rows — single frame OR chunk): no null prefix.
        # 3. AR frame 0 / image2video (action is None): all null tokens.
        if input_action_tokens is not None:
            # input_action_tokens shape: (1, T*tcf, D) or (T*tcf, D) for training; (T*tcf, D) for AR units
            if input_action_tokens.dim() == 3:
                real_actions = input_action_tokens.squeeze(0)  # [T*tcf,action_dim] or [N,action_dim]
            else:
                real_actions = input_action_tokens  # [N,action_dim]
            null_tokens = torch.zeros(tcf, action_dim, device=device, dtype=real_actions.dtype)  # [tcf,action_dim]
            if real_actions.shape[0] == latent_t * tcf:
                # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf):
                # every supertoken carries a real action, no null prefix.
                all_action_tokens = real_actions
                null_action_flag = False
            elif real_actions.shape[0] == (latent_t - 1) * tcf:
                # Conditioning frame present: null for supertoken 0, real for 1..T-1
                all_action_tokens = torch.cat([null_tokens, real_actions], dim=0)  # [T*tcf,action_dim]
                null_action_flag = True
            else:
                raise ValueError(
                    "Temporal-causal action tokens must have either latent_t*tcf rows for AR chunks "
                    f"or (latent_t-1)*tcf rows for whole-clip training; got {real_actions.shape[0]} rows "
                    f"for latent_t={latent_t}, tcf={tcf}."
                )
        else:
            # AR frame 0 or image2video: all action tokens are null
            all_action_tokens = torch.zeros(
                latent_t * tcf, action_dim, device=device, dtype=dtype
            )  # [T*tcf,action_dim]
            null_action_flag = True
    else:
        # pack_action_tokens=False: action tokens must not be supplied.
        assert input_action_tokens is None, (
            "pack_action_tokens=False requires input_action_tokens=None; got a non-None tensor."
        )
        null_action_flag = False

    # Record vision token shapes and tokens
    packed_seq.vision.token_shapes.append((latent_t, patch_h, patch_w))
    packed_seq.vision.tokens.append(input_vision_tokens)

    # Vision conditioning mask: (T, 1, 1)
    # Out-of-range conditioning indexes are silently dropped here.
    condition_set_vision = {idx for idx in condition_frame_indexes_vision if 0 <= idx < latent_t}
    vision_condition_mask = torch.zeros((latent_t, 1, 1), device=device, dtype=dtype)  # [T,1,1]
    for fidx in condition_set_vision:
        vision_condition_mask[fidx, 0, 0] = 1.0
    packed_seq.vision.condition_mask.append(vision_condition_mask)

    vision_noisy_frame_indexes = torch.tensor(
        [idx for idx in range(latent_t) if idx not in condition_set_vision],
        device=device,
        dtype=torch.long,
    )  # [N_noisy_frames]
    packed_seq.vision.noisy_frame_indexes.append(vision_noisy_frame_indexes)

    if pack_action_tokens:
        # Action token shapes: latent_t * tcf total (including null tokens)
        packed_seq.action.token_shapes.append((latent_t * tcf,))
        packed_seq.action.tokens.append(all_action_tokens)

        # Action conditioning mask: all action tokens are conditioning (not supervised)
        # Null tokens are always conditioning; real actions are conditioning too (they are inputs)
        action_condition_mask = torch.ones((latent_t * tcf, 1), device=device, dtype=dtype)  # [T*tcf,1]
        packed_seq.action.condition_mask.append(action_condition_mask)

    # Pack in interleaved supertoken order: [action_t, vision_t] for each frame t
    # (or just [vision_t] per frame when pack_action_tokens=False)
    curr = packed_seq.curr
    total_split_len = 0

    # mRoPE: snapshot offset before this sample, compute IDs
    if packed_seq._use_mrope:
        temporal_offset = packed_seq._mrope_temporal_offset
        effective_vision_fps = vision_fps if enable_fps_modulation else None

        # AR generation (single frame OR chunk) is detected by every frame carrying a
        # real action (``real_actions`` has ``latent_t*tcf`` rows). There, vision AND
        # action both use start_frame_offset=1 so the last action in each group
        # co-locates with its vision frame, mirroring whole-clip training; the caller
        # (pack_input_sequence_autoregressive) seeds temporal_offset one frame-stride
        # back to compensate. Whole-clip training (frame 0 is the null conditioning
        # frame, ``real_actions`` has ``(T-1)*tcf`` rows) keeps vision start_frame_offset=0.
        #
        # ``real_actions`` is only bound when action tokens were supplied; the
        # short-circuiting ``and`` chain guarantees it is not read otherwise.
        all_frames_have_real_action = (
            pack_action_tokens and input_action_tokens is not None and real_actions.shape[0] == latent_t * tcf
        )
        vision_sfo = 1 if all_frames_have_real_action else 0

        vision_ids_flat, new_offset = get_3d_mrope_ids_vae_tokens(
            grid_t=latent_t,
            grid_h=patch_h,
            grid_w=patch_w,
            temporal_offset=temporal_offset,
            reset_spatial_indices=packed_seq._mrope_reset_spatial,
            fps=effective_vision_fps,
            base_fps=base_fps,
            temporal_compression_factor=tcf,
            start_frame_offset=vision_sfo,
        )  # vision_ids_flat: [3,T*patch_h*patch_w]

        if pack_action_tokens:
            effective_action_fps = action_fps if enable_fps_modulation else None

            # Action IDs. Real action tokens use start_frame_offset=1 so the last
            # sub-token of a group co-locates with its vision frame. Whole-clip training
            # has a null action at frame 0 (the conditioning frame); AR units have a real
            # action for every frame.
            fps_active = effective_action_fps is not None
            t_dtype = torch.float32 if fps_active else torch.long
            t_offset = float(temporal_offset) if fps_active else int(temporal_offset)
            # Null action IDs: temporal axis pinned to the sample's starting offset,
            # both spatial axes zero.
            null_t = torch.full((tcf,), t_offset, dtype=t_dtype)  # [tcf]
            null_hw = torch.zeros(tcf, dtype=t_dtype)  # [tcf]
            null_ids = torch.stack([null_t, null_hw, null_hw])  # [3,tcf]

            def _real_action_ids(n_frames: int, start_frame_offset: int) -> torch.Tensor:
                # Position IDs for ``n_frames * tcf`` real action tokens, requested as a
                # 1x1 spatial grid and regrouped per frame. The helper's returned
                # offset is discarded; only the vision call advances the offset.
                flat, _ = get_3d_mrope_ids_vae_tokens(
                    grid_t=n_frames * tcf,
                    grid_h=1,
                    grid_w=1,
                    temporal_offset=temporal_offset,
                    reset_spatial_indices=packed_seq._mrope_reset_spatial,
                    fps=effective_action_fps,
                    base_fps=base_fps,
                    temporal_compression_factor=1,
                    base_temporal_compression_factor=tcf,
                    start_frame_offset=start_frame_offset,
                )
                return flat.reshape(3, n_frames, tcf)  # [3,n_frames,tcf]

            if all_frames_have_real_action:
                # AR generation (single frame: tcf == 1*tcf, or chunk: latent_t*tcf):
                # every supertoken carries a real action. start_frame_offset=1 puts
                # a_{j-1}'s last sub-token on vision frame j -- the whole-clip TF
                # training layout. The caller seeds temporal_offset (N-1) frame-strides
                # back to compensate.
                action_ids_3d = _real_action_ids(latent_t, start_frame_offset=1)  # [3,T,tcf]
            elif latent_t > 1:
                # Whole-clip training: supertoken 0 = null (conditioning frame), frames
                # 1..T-1 = real with start_frame_offset=1. Covers real-action training
                # (real_actions has (T-1)*tcf rows) and the architectural all-null layout
                # (input_action_tokens is None); the tokens differ but the IDs match.
                null_ids_3d = null_ids.reshape(3, 1, tcf)  # [3,1,tcf]
                real_ids_3d = _real_action_ids(latent_t - 1, start_frame_offset=1)  # [3,T-1,tcf]
                action_ids_3d = torch.cat([null_ids_3d, real_ids_3d], dim=1)  # [3,T,tcf]
            else:
                # AR frame 0 / image2video (latent_t == 1, no action): only null.
                action_ids_3d = null_ids.reshape(3, 1, tcf)  # [3,1,tcf]

            # (3, T*H*W) → (3, T, H*W)
            vision_ids_3d = vision_ids_flat.reshape(3, latent_t, patches_per_frame)  # [3,T,patch_h*patch_w]

            # Interleave per frame: (3, T, tcf+H*W) → (3, T*S)
            interleaved_ids = torch.cat([action_ids_3d, vision_ids_3d], dim=2).reshape(
                3, latent_t * supertoken_len
            )  # [3,T*S]
            packed_seq.position_ids.append(interleaved_ids)
        else:
            # No action tokens: just vision IDs, already in (3, T*H*W) order.
            packed_seq.position_ids.append(vision_ids_flat)

        # Advance the running offset using the value returned by the vision call.
        packed_seq._mrope_temporal_offset = new_offset

    # Assign sequence positions frame by frame. Under mRoPE the position IDs were
    # already appended above as one tensor, so only the 1D RoPE path adds IDs here.
    for frame_t in range(latent_t):
        if pack_action_tokens:
            # Pack action tokens for this frame (indexes only; tokens already stored in packed_seq.action.tokens)
            action_indexes = list(range(curr, curr + tcf))
            packed_seq.action.sequence_indexes.extend(action_indexes)
            # Action tokens are never in MSE loss (always conditioning)
            curr += tcf
            total_split_len += tcf

            if not packed_seq._use_mrope:
                packed_seq.position_ids.extend([curr_rope_id] * tcf)

        # Pack vision tokens for this frame
        frame_indexes = list(range(curr, curr + patches_per_frame))
        packed_seq.vision.sequence_indexes.extend(frame_indexes)
        curr += patches_per_frame
        total_split_len += patches_per_frame

        if not packed_seq._use_mrope:
            packed_seq.position_ids.extend([curr_rope_id] * patches_per_frame)

        # Vision MSE loss: supervise non-conditioning frames
        if frame_t not in condition_set_vision:
            packed_seq.vision.mse_loss_indexes.extend(frame_indexes)
            # One timestep entry per supervised patch; a tensor timestep is read per frame.
            frame_ts = input_timestep[frame_t].item() if isinstance(input_timestep, torch.Tensor) else input_timestep
            packed_seq.vision.timesteps.extend([frame_ts] * patches_per_frame)

    packed_seq.curr = curr
    return total_split_len, null_action_flag
b/cosmos_framework/data/vfm/sequence_packing/types.py @@ -0,0 +1,406 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Dataclasses and plan helpers for VFM sequence packing.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +if TYPE_CHECKING: + from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean + + +@dataclass +class ModalityData: + """Unified container for a single generation modality's data. + + This dataclass serves dual purposes: + 1. During packing: Acts as a builder, accumulating data in lists + 2. After finalize(): Holds finalized tensors ready for model consumption + + Attributes: + sequence_indexes: Indices in the packed sequence where this modality's tokens appear. + List during building, Tensor after finalize(). + timesteps: Diffusion timesteps for each noised token. + List during building, Tensor after finalize(). + mse_loss_indexes: Indices where MSE loss should be computed (noised tokens only). + List during building, Tensor after finalize(). + token_shapes: Shape metadata for each sample's tokens. + For vision: list of (T, H, W) tuples. + For action: list of (T,) tuples. + tokens: The actual latent tokens, as a list of tensors; finalize() passes the list through unchanged. + condition_mask: Mask indicating clean frames (1=clean, 0=noised). Only after finalize(). + noisy_frame_indexes: Indices of noised frames. Constructed from condition_mask during + sequence packing to reduce GPU->CPU synchronization later. Only after finalize(). + domain_id: Domain ID for multi-domain training. Only after finalize(). NOTE: only used for action modality. + raw_action_dim: Raw action dimension. Only after finalize(). NOTE: only used for action modality. 
+ """ + + # Core tracking (list during build, tensor after finalize) + sequence_indexes: list[int] | torch.Tensor = field(default_factory=list) + timesteps: list[float] | torch.Tensor = field(default_factory=list) + mse_loss_indexes: list[int] | torch.Tensor = field(default_factory=list) + # list[tuple[int,int,int]] for vision, list[tuple[int]] for action, list[tuple[int,int,int]] for sound + token_shapes: list = field(default_factory=list) + + # Populated during finalization (from GenerationDataClean / noise path) + tokens: list[torch.Tensor] = field(default_factory=list) + condition_mask: list[torch.Tensor] = field(default_factory=list) + noisy_frame_indexes: list[torch.Tensor] = field(default_factory=list) + domain_id: list[torch.Tensor] = field(default_factory=list) + raw_action_dim: list[torch.Tensor | None] | None = field(default_factory=list) + + def to_cuda(self) -> None: + """Move all tensor fields to CUDA in-place.""" + if isinstance(self.sequence_indexes, torch.Tensor): + self.sequence_indexes = self.sequence_indexes.cuda() + if isinstance(self.timesteps, torch.Tensor): + self.timesteps = self.timesteps.cuda() + if isinstance(self.mse_loss_indexes, torch.Tensor): + self.mse_loss_indexes = self.mse_loss_indexes.cuda() + self.tokens = [token.cuda() for token in self.tokens] + self.condition_mask = [cm.cuda() for cm in self.condition_mask] + self.noisy_frame_indexes = [ni.cuda() for ni in self.noisy_frame_indexes] + self.domain_id = [d.cuda() for d in self.domain_id] + # raw_action_dim is optional (e.g., when action-channel masking is disabled). + if self.raw_action_dim is not None: + self.raw_action_dim = [d.cuda() if d is not None else None for d in self.raw_action_dim] + + +@dataclass +class PackedSequence: + """Unified sequence container - works as builder during packing and final output. 
+ + This dataclass replaces the old SequenceStatus + PackedSequence pattern: + - Build phase: Accumulate data using lists, modalities use ModalityData builders + - After finalize(): Ready for model consumption with tensors + + Attributes: + # Sequence structure + sample_lens: Length of each sample in the packed sequence. + split_lens: Length of each split (text/vision/action sections). + attn_modes: Attention mode for each split ('causal', 'full'). + is_image_batch: Whether this batch contains images (vs videos). + sequence_length: Total length of packed sequence. Computed during finalize(). + + # Build-time tracking (not used after finalize) + curr: Current position in the packed sequence during building. + + # Text modality (list during build, tensor after finalize) + text_ids: All text token IDs (including special tokens). + text_indexes: Indices where text tokens appear in sequence. + position_ids: RoPE position IDs for all tokens. + + # Loss computation - Cross Entropy (text) + label_ids: Label IDs for cross-entropy loss. + ce_loss_indexes: Indices for computing cross-entropy loss. + ce_loss_weights: Weights for cross-entropy loss. + + # Generation modalities - named fields for type safety + vision: Vision modality data (images/videos). None if no vision in batch. + action: Action modality data (robotics). None if no actions in batch. + sound: Sound modality data (audio). None if no sound in batch. 
+ """ + + # Sequence structure + sample_lens: list[int] = field(default_factory=list) + split_lens: list[int] = field(default_factory=list) + attn_modes: list[str] = field(default_factory=list) + is_image_batch: bool = False + sequence_length: int = 0 + + # Build-time tracking (used during packing, not after finalize) + curr: int = 0 + + # Text modality (list during build, tensor after finalize) + text_ids: list[int] | torch.Tensor = field(default_factory=list) + text_indexes: list[int] | torch.Tensor = field(default_factory=list) + position_ids: list[int] | torch.Tensor = field(default_factory=list) + + # Loss computation - Cross Entropy (text) + label_ids: list[int] | torch.Tensor | None = field(default_factory=list) + ce_loss_indexes: list[int] | torch.Tensor | None = field(default_factory=list) + ce_loss_weights: list[float] | torch.Tensor | None = field(default_factory=list) + + # Build-time mRoPE tracking (used during packing, not after finalize) + # When _use_mrope=True, position_ids accumulates (3, N) tensors instead of ints, + # and finalize() produces a (3, total_seq_len) tensor instead of (total_seq_len,). + _use_mrope: bool = False + # Running temporal index for mRoPE position ID generation within a single sample. + # Reset to 0 at the start of each sample, then advanced by text and vision helpers + # as segments are packed. Action reuses the pre-vision snapshot (parallel temporal + # range) without advancing it. Float when FPS modulation is enabled. + # E.g. offset=0 -> text(4 tokens) -> offset=4 -> vision(3 frames) -> offset=7. + _mrope_temporal_offset: int | float = 0 + _mrope_reset_spatial: bool = True + + # Temporal causal: whether supertoken 0's action slot contains null tokens. + # True for all training calls and AR frame 0; False for AR frame N>0 (real actions). + # Used by three_way_attention to zero out V for null action tokens (inline when attention_meta.null_action_supertokens=True). 
+ null_action_supertokens: bool = False + + # Temporal causal: number of action tokens prefixing each vision supertoken. + # Equals temporal_compression_factor when actions are packed inline; 0 when + # action_gen=False or for non-temporal-causal layouts. Single source of truth + # for downstream attention/KV-cache code (per-supertoken layout is + # num_action_tokens_per_supertoken + H_p * W_p). + num_action_tokens_per_supertoken: int = 0 + + # Generation modalities - NAMED FIELDS for type safety + vision: ModalityData | None = None + action: ModalityData | None = None + sound: ModalityData | None = None + + def finalize( + self, + gen_data_clean: GenerationDataClean, + ) -> "PackedSequence": + """Convert all lists to tensors and compute derived values. + + Args: + gen_data_clean: GenerationDataClean for metadata (e.g., action domain IDs). + + Returns: + New PackedSequence instance with tensors instead of lists. + """ + # Compute sequence length + sequence_length = sum(self.sample_lens) + sample_lens = self.sample_lens.copy() + split_lens = self.split_lens.copy() + attn_modes = self.attn_modes.copy() + + # Prepare loss-related tensors (cross-entropy) + label_ids: torch.Tensor | None = None + ce_loss_indexes: torch.Tensor | None = None + ce_loss_weights: torch.Tensor | None = None + if self.label_ids and len(self.label_ids) > 0: + label_ids = torch.tensor(self.label_ids) # [N_ce_tokens] + ce_loss_indexes = torch.tensor(self.ce_loss_indexes) # [N_ce_tokens] + ce_loss_weights = torch.tensor(self.ce_loss_weights) # [N_ce_tokens] + + # The condition_mask and noisy_frame_indexes are kept as lists to support variable shapes. 
+ + # Finalize vision modality + vision: ModalityData | None = None + if self.vision is not None and len(self.vision.sequence_indexes) > 0: + vision = ModalityData( + sequence_indexes=torch.tensor(self.vision.sequence_indexes, dtype=torch.long), # [N_vision_tokens] + timesteps=torch.tensor(self.vision.timesteps), # [N_vision_noisy_tokens] + mse_loss_indexes=torch.tensor( + self.vision.mse_loss_indexes, dtype=torch.long + ), # [N_vision_noisy_tokens] + token_shapes=list(self.vision.token_shapes), + tokens=self.vision.tokens, + condition_mask=list(self.vision.condition_mask), + noisy_frame_indexes=list(self.vision.noisy_frame_indexes), + ) + + # Finalize action modality + action: ModalityData | None = None + if self.action is not None and len(self.action.sequence_indexes) > 0: + action = ModalityData( + sequence_indexes=torch.tensor(self.action.sequence_indexes, dtype=torch.long), # [N_action_tokens] + timesteps=torch.tensor(self.action.timesteps), # [N_action_noisy_tokens] + mse_loss_indexes=torch.tensor( + self.action.mse_loss_indexes, dtype=torch.long + ), # [N_action_noisy_tokens] + token_shapes=list(self.action.token_shapes), + tokens=self.action.tokens, + condition_mask=list(self.action.condition_mask), # Keep as list to support variable shapes + noisy_frame_indexes=list(self.action.noisy_frame_indexes), + domain_id=( + gen_data_clean.action_domain_id + if gen_data_clean.action_domain_id is not None + else [torch.zeros(1, dtype=torch.long)] * len(self.action.token_shapes) + ), + raw_action_dim=gen_data_clean.raw_action_dim, + ) + + # Finalize sound modality (placeholder for future) + sound: ModalityData | None = None + if self.sound is not None and len(self.sound.sequence_indexes) > 0: + sound = ModalityData( + sequence_indexes=torch.tensor(self.sound.sequence_indexes, dtype=torch.long), # [N_sound_tokens] + timesteps=torch.tensor(self.sound.timesteps), # [N_sound_noisy_tokens] + mse_loss_indexes=torch.tensor(self.sound.mse_loss_indexes, dtype=torch.long), # 
[N_sound_noisy_tokens] + token_shapes=list(self.sound.token_shapes), + tokens=self.sound.tokens, + condition_mask=list(self.sound.condition_mask), + noisy_frame_indexes=list(self.sound.noisy_frame_indexes), + ) + + # Finalize position IDs: 3D mRoPE (3, seq_len) or 1D RoPE (seq_len,) + if self._use_mrope and len(self.position_ids) > 0 and isinstance(self.position_ids[0], torch.Tensor): + mrope_tensors: list[torch.Tensor] = self.position_ids # type: ignore[assignment] + position_ids = torch.cat(mrope_tensors, dim=1) # [3,actual_seq_len] + else: # Original 1D RoPE from Bagel, where all the media tokens share the same 1D position ID + position_ids = torch.tensor(self.position_ids) # [seq_len] + + return PackedSequence( + # Sequence structure + sequence_length=sequence_length, + sample_lens=sample_lens, + split_lens=split_lens, + attn_modes=attn_modes, + is_image_batch=gen_data_clean.is_image_batch, + # Text modality (converted to tensors) + text_ids=torch.tensor(self.text_ids, dtype=torch.long), # [N_text_tokens] + text_indexes=torch.tensor(self.text_indexes, dtype=torch.long), # [N_text_tokens] + position_ids=position_ids, # [seq_len] or [3,seq_len] + # Loss computation - Cross Entropy + label_ids=label_ids, + ce_loss_indexes=ce_loss_indexes, + ce_loss_weights=ce_loss_weights, + # Generation modalities + vision=vision, + action=action, + sound=sound, + # Temporal causal + null_action_supertokens=self.null_action_supertokens, + num_action_tokens_per_supertoken=self.num_action_tokens_per_supertoken, + ) + + def to_cuda(self) -> None: + """Move all tensor fields to CUDA in-place.""" + if isinstance(self.text_ids, torch.Tensor): + self.text_ids = self.text_ids.cuda() + if isinstance(self.text_indexes, torch.Tensor): + self.text_indexes = self.text_indexes.cuda() + if isinstance(self.position_ids, torch.Tensor): + self.position_ids = self.position_ids.cuda() + if isinstance(self.label_ids, torch.Tensor): + self.label_ids = self.label_ids.cuda() + if 
isinstance(self.ce_loss_indexes, torch.Tensor): + self.ce_loss_indexes = self.ce_loss_indexes.cuda() + if isinstance(self.ce_loss_weights, torch.Tensor): + self.ce_loss_weights = self.ce_loss_weights.cuda() + if self.vision is not None: + self.vision.to_cuda() + if self.action is not None: + self.action.to_cuda() + if self.sound is not None: + self.sound.to_cuda() + + +@dataclass +class SequencePlan: + """Plan describing which modalities are present in a sample. + + This dataclass tracks the presence of different modalities (text, vision, action) + and their conditioning configurations for a dataset sample. Unlike PackedSequence, + which holds the actual tensor data, this class provides a lightweight summary + of what modalities exist and how they should be conditioned. + + Attributes: + has_text: Whether text/caption tokens are present for this sample. + Used for text-conditioned generation (e.g., text-to-image/video). + has_vision: Whether vision input (image or video latents) is present. + Defaults to False. + condition_frame_indexes_vision: Indexes of latent vision frames that are clean/conditioning. + [] means all frames are noised/supervised. + All frames specified means all frames are clean (no MSE supervision). + For multi-item samples (e.g. image editing where each sample has multiple + separately-encoded images), this applies to each vision item individually. + The number of items per sample is tracked by + ``GenerationDataClean.num_vision_items_per_sample``. + has_action: Whether action input is present for robotics/embodied AI tasks. + Defaults to False. + condition_frame_indexes_action: Indexes of action steps that are clean/conditioning. + [] means all steps are noised/supervised. + All steps specified means all steps are clean (no MSE supervision). 
+ """ + + # -- understanding (text conditioning) -- + has_text: bool + + # -- vision modality -- + has_vision: bool = False + condition_frame_indexes_vision: list[int] = field(default_factory=list) + # If True, all vision items in this sample share the same temporal mRoPE grid + # (controlnet-style transfer: target frame i is spatio-temporally aligned with + # control frame i). Each item gets the same temporal_offset; spatial reset + # behavior is unchanged. Requires num_vision_items_per_sample > 1, equal latent_t, + # and equal fps across items. Default False preserves single-clip and + # image-editing semantics where items represent distinct time states. + share_vision_temporal_positions: bool = False + + # -- action modality -- + has_action: bool = False + condition_frame_indexes_action: list[int] = field(default_factory=list) + action_start_frame_offset: int = 1 + + # -- sound modality -- + has_sound: bool = False + condition_frame_indexes_sound: list[int] = field(default_factory=list) + + def as_dict(self) -> dict: + return { + "has_text": self.has_text, + "has_vision": self.has_vision, + "has_action": self.has_action, + "has_sound": self.has_sound, + "condition_frame_indexes_vision": self.condition_frame_indexes_vision, + "condition_frame_indexes_action": self.condition_frame_indexes_action, + "condition_frame_indexes_sound": self.condition_frame_indexes_sound, + "share_vision_temporal_positions": self.share_vision_temporal_positions, + } + + +def build_sequence_plans_from_data_batch( + data_batch: dict, + input_video_key, + input_image_key: str, +) -> list[SequencePlan]: + """Build or retrieve sequence plans from a data batch dictionary. + + This function extracts sequence plans from the data batch if they exist, + otherwise creates default SequencePlan objects for each sample + in the batch. + + Args: + data_batch: Dictionary containing the data batch from the dataloader. + Expected keys include 'video' or other tensors to determine batch size. 
+ If 'sequence_plan' key exists, those plans are returned directly. + + Returns: + List of SequencePlan objects, one per sample in the batch. + """ + # For new modalities, please generate the sequence_plan in the dataset class!!!! + + # If sequence_plan already exists in data_batch, return it + if "sequence_plan" in data_batch: + return data_batch["sequence_plan"] + + assert "action" not in data_batch or data_batch["action"] is None, "Action data SHOULD have sequence_plans!" + assert "sound" not in data_batch or data_batch["sound"] is None, "Sound data SHOULD have sequence_plans!" + + # Determine batch size from available tensors + batch_size = 0 + for key in [input_video_key, input_image_key]: + if key in data_batch: + val = data_batch[key] + if isinstance(val, torch.Tensor): + batch_size = val.shape[0] + break + elif isinstance(val, list): + batch_size = len(val) + break + + if batch_size == 0: + raise ValueError( + f"Cannot determine batch size from data_batch. Expected {input_video_key}, {input_image_key}, or similar key." + ) + + # Build default SequencePlan objects + return [ + SequencePlan( + has_text=True, # Has text prompt! + has_vision=True, + condition_frame_indexes_vision=[], # No conditioning frames! 
+ ) + for _ in range(batch_size) + ] diff --git a/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py b/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py index 66db4fb..3537164 100644 --- a/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py +++ b/cosmos_framework/model/tokenizer/evaluation/reconstruction_metrics.py @@ -487,6 +487,8 @@ def reset(self) -> None: self._fid_metric.reset() +# COSMOS-RELEASE-END-IGNORE + __all__ = [ "TokenizerMetric", "PSNRMetric", diff --git a/cosmos_framework/model/tokenizer/models/dense_runtime.py b/cosmos_framework/model/tokenizer/models/dense_runtime.py index 59c6b9d..9382028 100644 --- a/cosmos_framework/model/tokenizer/models/dense_runtime.py +++ b/cosmos_framework/model/tokenizer/models/dense_runtime.py @@ -461,6 +461,15 @@ def decode( When ``pixel_trim`` is enabled and ``pad_frames > 0``, the latent contains boundary tokens from encoding. After decoding, the corresponding boundary pixel frames are trimmed from each chunk. + + **Output shape contract**: + - Video (``temporal_patches > 1``): ``[B, T, H, W, C]`` where T is the + total number of decoded pixel frames across all chunks (after trim). + - Image (``temporal_patches == 1``): ``[B, 1, H, W, C]``. The image + latent is decoded into ``patch_time`` identical frames (it was encoded + from ``patch_time`` copies of the same frame); only the last frame is + kept. This differs from pre-``dense_runtime`` behaviour where the + full ``[B, patch_time, H, W, C]`` was returned. """ if self.decoder_cache_spec.patch_frames != 0: raise NotImplementedError("Dense runtime decoder V1 does not support KV cache.") @@ -481,21 +490,20 @@ def decode( pad_frames = self.pad_frames trim_pixel = self.pixel_trim and pad_frames > 0 - patch_time = self.patch_size[0] # Images were encoded as a single latent (no noncausal first chunk). # Videos have temporal_patches > 1: latent[0] is the noncausal first frame. 
is_image = temporal_patches == 1 + # Patch 0 is always a single-latent chunk — either the noncausal first + # frame (video) or the sole image latent. Both were encoded from + # [frame × patch_time] copies, so all decoded frames are equivalent; + # keep the last one. For images temporal_patches == 1, so the loop + # below is empty and this is the only chunk. decoded_chunks: list[torch.Tensor] = [] + decoded_first = self._decode_latent_chunk(latent[:, 0:1]) # [B, patch_time, H, W, C] + decoded_chunks.append(decoded_first[:, -1:]) - if not is_image: - # Noncausal first latent: decode → patch_time pixel frames, keep last - # (the reconstructed original first frame). - first_latent = latent[:, 0:1] - decoded_first = self._decode_latent_chunk(first_latent) # [B, patch_time, H, W, C] - decoded_chunks.append(decoded_first[:, -1:]) - - for start_patch in range(0 if is_image else 1, temporal_patches, chunk_patch_frames): + for start_patch in range(1, temporal_patches, chunk_patch_frames): end_patch = min(start_patch + chunk_patch_frames, temporal_patches) latent_chunk = latent[:, start_patch:end_patch] decoded_chunk = self._decode_latent_chunk(latent_chunk) diff --git a/cosmos_framework/model/vfm/algorithm/loss/flow_matching.py b/cosmos_framework/model/vfm/algorithm/loss/flow_matching.py index b10bfd9..ecece25 100644 --- a/cosmos_framework/model/vfm/algorithm/loss/flow_matching.py +++ b/cosmos_framework/model/vfm/algorithm/loss/flow_matching.py @@ -23,7 +23,6 @@ def compute_flow_matching_loss( has_valid_tokens: bool, rectified_flow: RectifiedFlow, tensor_kwargs_fp32: dict, - loss_scale: float | None = None, raw_action_dim: list[torch.Tensor] | None = None, normalize_by_active: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -43,11 +42,6 @@ def compute_flow_matching_loss( rectified_flow: The rectified flow object for time weighting. tensor_kwargs_fp32: Dict of dtype/device kwargs forwarded to ``rectified_flow.train_time_weight``. 
- loss_scale: Optional per-modality loss scale. Falls back to the global - ``rectified_flow_training_config.loss_scale`` when *None*. - (Currently unused inside the function body — scaling is applied at the - call site in ``OmniMoTModel._compute_losses``. Kept in the signature - to preserve the original API.) normalize_by_active: When True, normalize per-instance loss by the count of active (noisy) elements rather than all elements. Preserves the ``sum / active_count`` semantics needed for distillation critics where diff --git a/cosmos_framework/model/vfm/diffusion/rectified_flow.py b/cosmos_framework/model/vfm/diffusion/rectified_flow.py index 439bff5..bea0e04 100644 --- a/cosmos_framework/model/vfm/diffusion/rectified_flow.py +++ b/cosmos_framework/model/vfm/diffusion/rectified_flow.py @@ -12,6 +12,10 @@ class TrainTimeSampler: _WAVER_MODE_S = 1.29 + # 99.9th and 0.5th percentiles of the standard normal, used for ltx2 stretching. + _LTX2_NORMAL_999_PCTILE = 3.0902 + _LTX2_NORMAL_005_PCTILE = -2.5758 + _LTX2_UNIFORM_PROB = 0.1 def __init__( self, @@ -26,12 +30,22 @@ def __call__( device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.float32, generator: torch.Generator | None = None, + shifts: torch.Tensor | None = None, ) -> torch.Tensor: """ - Sample time tensor for training + Sample sigma ∈ [0, 1] for training. + + Args: + batch_size: Number of samples. + device: Target device. + dtype: Target dtype. + shifts: Optional 1-D per-sample shift values, shape ``(batch_size,)``. For non-ltx2 + distributions, the raw sample ``t`` is warped through + ``sigma = shift * t / (1 + (shift-1) * t)``. For ``ltx2``, ``shifts`` is + required and used as the per-sample logit-normal mean. Returns: - torch.Tensor: Time tensor, shape (batch_size,) + torch.Tensor: sigma ∈ [0, 1], shape (batch_size,). 
""" if self.distribution == "uniform": t = torch.rand((batch_size,), generator=generator).to(device=device, dtype=dtype) # [B] @@ -41,8 +55,39 @@ def __call__( u = torch.rand((batch_size,), dtype=torch.float32, generator=generator) # [B] t = 1.0 - u - self._WAVER_MODE_S * (torch.cos(torch.pi / 2.0 * u) ** 2 - 1 + u) # [B] t = t.to(device=device, dtype=dtype) # [B] + elif self.distribution == "ltx2": + # Shifted logit-normal with percentile-based stretching and 10% uniform fallback. + assert shifts is not None, "'ltx2' distribution requires per-sample shifts." + # shift(sigmoid(t), s) = sigmoid(t + ln(s)) + mu = torch.log(shifts.to(device=torch.device("cpu"), dtype=torch.float32)) # [B] + std = 1.0 + eps = 1e-3 + + normal_samples = torch.randn((batch_size,), dtype=torch.float32, generator=generator) * std + mu # [B] + logitnormal_samples = torch.sigmoid(normal_samples) # [B] + + percentile_999 = torch.sigmoid(mu + self._LTX2_NORMAL_999_PCTILE * std) # [B] + percentile_005 = torch.sigmoid(mu + self._LTX2_NORMAL_005_PCTILE * std) # [B] + + zero_terminal_raw = (logitnormal_samples - percentile_005) / (percentile_999 - percentile_005) + stretched = torch.where( + zero_terminal_raw >= eps, + zero_terminal_raw, + 2 * eps - zero_terminal_raw, + ) + stretched = torch.clamp(stretched, 0, 1) + + uniform = (1 - eps) * torch.rand((batch_size,), dtype=torch.float32, generator=generator) + eps + prob = torch.rand((batch_size,), dtype=torch.float32, generator=generator) + t = torch.where(prob > self._LTX2_UNIFORM_PROB, stretched, uniform).to(device=device, dtype=dtype) + + return t # skip post-shift else: - raise NotImplementedError(f"Time distribution '{self.dist}' is not implemented.") + raise NotImplementedError(f"Time distribution '{self.distribution}' is not implemented.") + + if shifts is not None: + shifts = shifts.to(device=device, dtype=dtype) # [B] + t = shifts * t / (1 + (shifts - 1) * t) # [B], sigma ∈ [0,1] return t # [B] @@ -86,18 +131,20 @@ def __init__( self.device 
= torch.device(device) if isinstance(device, str) else device self.dtype = torch.dtype(dtype) if isinstance(dtype, str) else dtype - def sample_train_time(self, batch_size: int, iteration: int | None = None) -> torch.Tensor: + def sample_train_time(self, batch_size: int, iteration: int | None = None, shifts: torch.Tensor | None = None): r"""This method calls the `TrainTimeSampler` to sample training times. Args: - batch_size: Number of time values to sample. + batch_size: Number of samples. iteration: When provided, sampling uses a local generator seeded from ``(iteration, rank)`` so results are identical across independent runs regardless of prior global RNG state. + shifts: Optional 1-D shift tensor, shape ``(batch_size,)``. Forwarded to + ``TrainTimeSampler.__call__``; see that docstring for details. Returns: t (`torch.Tensor`): - A tensor of sampled training times with shape `(batch_size,)`, + A tensor of sampled sigmas with shape `(batch_size,)`, matching the class specified `device` and `dtype`. 
""" generator = None @@ -105,7 +152,9 @@ def sample_train_time(self, batch_size: int, iteration: int | None = None) -> to rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 generator = torch.Generator() generator.manual_seed(iteration * 65536 + rank) - time = self.train_time_sampler(batch_size, device=self.device, dtype=self.dtype, generator=generator) + time = self.train_time_sampler( + batch_size, device=self.device, dtype=self.dtype, generator=generator, shifts=shifts + ) return time def get_discrete_timestamp(self, u, tensor_kwargs): diff --git a/cosmos_framework/model/vfm/mot/attention.py b/cosmos_framework/model/vfm/mot/attention.py index 82d1b00..9f950b9 100644 --- a/cosmos_framework/model/vfm/mot/attention.py +++ b/cosmos_framework/model/vfm/mot/attention.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: OpenMDW-1.1 import torch -from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from cosmos_framework.model.attention import ( attention, @@ -12,8 +11,6 @@ from cosmos_framework.model.attention.masks import CausalType from cosmos_framework.model.vfm.utils.memory import KVToStore, MemoryValue -flex_attention = torch.compile(flex_attention) - class SplitInfo: def __init__( @@ -30,8 +27,8 @@ def __init__( ): """ Actual len is the actual non-padded length of the packed sequence. - It's used to trim split_lens, attn_modes and sample_lens, which were - originally padded to max sequence length (likely for flex attention). + It's used to trim split_lens, attn_modes and sample_lens, which may + be padded to max sequence length by upstream packers. 
""" assert sum(sample_lens) == sum(split_lens), ( f"Sum of new sample lens {sum(sample_lens)} is not equal to sum of new split lens {sum(split_lens)}" @@ -60,33 +57,31 @@ def __init__( self.null_action_supertokens = null_action_supertokens -AttentionMaskType = BlockMask | SplitInfo +AttentionMaskType = SplitInfo _dotproduct_attention_cache = {} -from cosmos_framework.data.vfm.sequence_packing import ( - FactoredSequencePack, - JointSequencePack, - create_sparse_mask, - factored_from_joint_sequence, - from_joint, - from_mode_splits, +from cosmos_framework.data.vfm.sequence_packing.natten import ( generate_natten_metadata, generate_temporal_causal_natten_metadata, +) +from cosmos_framework.data.vfm.sequence_packing.runtime import ( + SequencePack, + from_mode_splits, get_all_seq, get_causal_seq, get_full_only_seq, - joint_from_joint_sequence, + sequence_pack_from_packed_sequence, ) def two_way_attention( - packed_query_states: FactoredSequencePack | JointSequencePack, - packed_key_states: FactoredSequencePack | JointSequencePack, - packed_value_states: FactoredSequencePack | JointSequencePack, -): + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, +) -> SequencePack: """ Performs two-way attention with causal and full attention. """ @@ -134,12 +129,12 @@ def two_way_attention( def three_way_attention( - packed_query_states: FactoredSequencePack | JointSequencePack, - packed_key_states: FactoredSequencePack | JointSequencePack, - packed_value_states: FactoredSequencePack | JointSequencePack, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, natten_metadata: dict | None, attention_meta: SplitInfo | None = None, -): +) -> SequencePack: """ Performs three-way attention, with understanding and generations attentions fully decomposed, and allows sparsity / multi-dimensional masking in the generation tower. 
@@ -238,72 +233,14 @@ def three_way_attention( return out_all -def pad_sequence(tensor, pad_size): - """ - Pad a tensor along the second-to-last dimension. - - Args: - tensor: Input tensor to pad - pad_size: Number of padding elements to add - - Returns: - Padded tensor with zeros added along dim=-2 - """ - if pad_size <= 0: - return tensor - pad_shape = list(tensor.shape) - pad_shape[-2] = pad_size - padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) - return torch.cat([tensor, padding], dim=-2) # [...,S+pad_size,...] - - -def block_flex_attention( - packed_query_states: FactoredSequencePack | JointSequencePack, - packed_key_states: FactoredSequencePack | JointSequencePack, - packed_value_states: FactoredSequencePack | JointSequencePack, - attention_mask: BlockMask, - block_size: int | None = None, -): - packed_queries = get_all_seq(packed_query_states) # [N,heads,head_dim] - packed_keys = get_all_seq(packed_key_states) # [N,heads,head_dim] - packed_values = get_all_seq(packed_value_states) # [N,heads,head_dim] - max_num_tokens = packed_query_states["max_num_tokens"] - - num_attention_heads = packed_queries.shape[1] - head_dim = packed_queries.shape[2] - - # Handle block mask attention with flex_attention - pad_size = max_num_tokens - packed_queries.shape[0] - packed_queries_padded = pad_sequence(packed_queries.permute(1, 0, 2), pad_size) # [heads,max_num_tokens,head_dim] - packed_keys_padded = pad_sequence(packed_keys.permute(1, 0, 2), pad_size) # [heads,max_num_tokens,head_dim] - packed_values_padded = pad_sequence(packed_values.permute(1, 0, 2), pad_size) # [heads,max_num_tokens,head_dim] - - packed_attn_output = flex_attention( - packed_queries_padded.unsqueeze(0), # [1,heads,max_num_tokens,head_dim] - packed_keys_padded.unsqueeze(0), # [1,heads,max_num_tokens,head_dim] - packed_values_padded.unsqueeze(0), # [1,heads,max_num_tokens,head_dim] - enable_gqa=True, - block_mask=attention_mask, - ) # [1,heads,max_num_tokens,head_dim] - assert 
isinstance(packed_attn_output, torch.Tensor) - - end_index = packed_attn_output.shape[2] - pad_size - packed_attn_output = packed_attn_output[0, :, :end_index, :] # [heads,N,head_dim] - packed_attn_output = packed_attn_output.transpose(0, 1).reshape( - -1, num_attention_heads * head_dim - ) # [N,heads*head_dim] - - return from_joint(packed_attn_output, packed_query_states) - - def dispatch_attention( - packed_query_states: FactoredSequencePack | JointSequencePack, - packed_key_states: FactoredSequencePack | JointSequencePack, - packed_value_states: FactoredSequencePack | JointSequencePack, - attention_mask: BlockMask | SplitInfo, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + attention_mask: SplitInfo, natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, -) -> tuple[FactoredSequencePack | JointSequencePack, KVToStore | None]: +) -> tuple[SequencePack, KVToStore | None]: assert memory_value is None, "Base dispatch_attention does not handle MemoryValue" if isinstance(attention_mask, SplitInfo) and attention_mask.is_three_way: output = three_way_attention( @@ -316,7 +253,7 @@ def dispatch_attention( elif isinstance(attention_mask, SplitInfo): output = two_way_attention(packed_query_states, packed_key_states, packed_value_states) else: - output = block_flex_attention(packed_query_states, packed_key_states, packed_value_states, attention_mask) + raise TypeError(f"Unsupported attention metadata: {type(attention_mask)}") return output, None @@ -344,35 +281,21 @@ def build_packed_sequence( num_action_tokens_per_supertoken: int = 0, null_action_supertokens: bool = False, pad_for_cuda_graphs: bool = False, -) -> tuple[FactoredSequencePack | JointSequencePack, AttentionMaskType, list | None]: +) -> tuple[SequencePack, AttentionMaskType, list | None]: """ Build the model input pack and attention meta for joint attention. Returns a tuple: (input_pack, attention_meta). 
""" device = packed_sequence.device natten_metadata_list = None - if joint_attn_implementation == "flex": - sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, device) - seqlen = sum(sample_lens) - attention_meta = create_block_mask( - sparse_mask, - B=1, - H=num_heads, - Q_LEN=seqlen, - KV_LEN=seqlen, - device=device, - BLOCK_SIZE=block_size, - _compile=True, - ) - make_pack = joint_from_joint_sequence - elif joint_attn_implementation == "two_way": + if joint_attn_implementation == "two_way": attention_meta = SplitInfo( split_lens=split_lens, attn_modes=attn_modes, sample_lens=sample_lens, actual_len=int(packed_sequence.shape[0]), ) - make_pack = factored_from_joint_sequence + make_pack = sequence_pack_from_packed_sequence elif joint_attn_implementation == "three_way": attention_meta = SplitInfo( split_lens=split_lens, @@ -385,7 +308,7 @@ def build_packed_sequence( num_action_tokens_per_supertoken=num_action_tokens_per_supertoken, null_action_supertokens=null_action_supertokens, ) - make_pack = factored_from_joint_sequence + make_pack = sequence_pack_from_packed_sequence # Some memory-driven attention paths implement temporal visibility in # their own attention kernels; skip NATTEN metadata for those paths. if not skip_natten_metadata: @@ -412,8 +335,7 @@ def build_packed_sequence( ) else: raise ValueError( - f"Invalid joint_attn_implementation: {joint_attn_implementation}. " - "Must be 'two_way', 'three_way', or 'flex'." + f"Invalid joint_attn_implementation: {joint_attn_implementation}. Must be 'two_way' or 'three_way'." 
) input_pack = make_pack( diff --git a/cosmos_framework/model/vfm/mot/attention_test.py b/cosmos_framework/model/vfm/mot/attention_test.py index 56c2dee..53112ee 100644 --- a/cosmos_framework/model/vfm/mot/attention_test.py +++ b/cosmos_framework/model/vfm/mot/attention_test.py @@ -9,7 +9,10 @@ import cosmos_framework.model.vfm.mot.attention as attention from cosmos_framework.model.attention.natten import NATTEN_SUPPORTED -from cosmos_framework.data.vfm.sequence_packing import ( +from cosmos_framework.model.vfm.mot.attention import ( + build_packed_sequence, +) +from cosmos_framework.data.vfm.sequence_packing.runtime import ( get_all_seq, get_gen_seq, get_und_seq, @@ -17,9 +20,6 @@ set_und_seq, zeros_like, ) -from cosmos_framework.model.vfm.mot.attention import ( - build_packed_sequence, -) MAX_SEQ_LEN = 24 SEQS_PER_BATCH = 4 @@ -64,7 +64,6 @@ def _test_attention_impls( torch.compiler.reset() IMPL_TO_FN = { - "flex": attention.block_flex_attention, "two_way": attention.two_way_attention, "three_way": attention.three_way_attention, } @@ -151,23 +150,7 @@ def make_pack_two_way(x): def make_pack_three_way(x): return _make_pack_decomposed(x, "three_way") - def make_pack_flex(x): - return build_packed_sequence( - "flex", - packed_sequence=x, - attn_modes=attn_modes, - split_lens=split_lens, - sample_lens=sample_lens, - packed_und_token_indexes=packed_und_idx_t, - packed_gen_token_indexes=packed_gen_idx_t, - num_heads=num_q_heads, - head_dim=head_dim, - num_layers=num_layers, - token_shapes=token_shapes, - )[0] - IMPL_TO_MAKE_PACK = { - "flex": make_pack_flex, "two_way": make_pack_two_way, "three_way": make_pack_three_way, } @@ -271,27 +254,8 @@ def forward(self, *args, **kwargs): key_joint_2["_causal_seq_offsets"] = query_joint_2["_causal_seq_offsets"] value_joint_2["_causal_seq_offsets"] = query_joint_2["_causal_seq_offsets"] - # Build a matching flex attention_meta for the given shapes kwargs_1 = {} kwargs_2 = {} - if impl_1 == "flex" or impl_2 == "flex": - 
packed_sequence = packed_qkv22 if impl_2 == "flex" else packed_qkv21 - _, attention_mask, _ = build_packed_sequence( - "flex", - packed_sequence=packed_sequence[:, :num_q_heads, :], - attn_modes=attn_modes, - split_lens=split_lens, - sample_lens=sample_lens, - packed_und_token_indexes=packed_und_idx_t, - packed_gen_token_indexes=packed_gen_idx_t, - num_heads=num_q_heads, - head_dim=head_dim, - num_layers=num_layers, - ) - if impl_1 == "flex": - kwargs_1["attention_mask"] = attention_mask - elif impl_2 == "flex": - kwargs_2["attention_mask"] = attention_mask # natten_metadata is a required argument, but setting it to None implements standard self attn. if impl_1 == "three_way": @@ -329,7 +293,7 @@ def forward(self, *args, **kwargs): ) torch.cuda.synchronize() - # joint vs factored test. should be same + # Independent packs for the same implementation should be the same. torch.testing.assert_close( get_all_seq(output1_factored)[:real_len], get_all_seq(output1_joint)[:real_len], atol=atol_self, rtol=rtol_self ) @@ -367,14 +331,30 @@ def forward(self, *args, **kwargs): @pytest.mark.L0 @pytest.mark.skipif(not NATTEN_SUPPORTED, reason="NATTEN is not available, or too old.") -def test_two_way_attention_cmp_flex_attn(): - _test_attention_impls("two_way", "flex") +def test_two_way_attention_vs_three_way_attention(): + _test_attention_impls("two_way", "three_way") @pytest.mark.L0 -@pytest.mark.skipif(not NATTEN_SUPPORTED, reason="NATTEN is not available, or too old.") -def test_two_way_attention_vs_three_way_attention(): - _test_attention_impls("two_way", "three_way") +def test_build_packed_sequence_rejects_flex(): + device = torch.device("cpu") + packed_sequence = torch.randn(4, 8, device=device) # [N,D] + packed_und_token_indexes = torch.tensor([0, 1], device=device, dtype=torch.long) # [N_und] + packed_gen_token_indexes = torch.tensor([2, 3], device=device, dtype=torch.long) # [N_gen] + + with pytest.raises(ValueError, match="Must be 'two_way' or 'three_way'"): + 
build_packed_sequence( + "flex", + packed_sequence=packed_sequence, + attn_modes=["causal", "full"], + split_lens=[2, 2], + sample_lens=[4], + packed_und_token_indexes=packed_und_token_indexes, + packed_gen_token_indexes=packed_gen_token_indexes, + num_heads=1, + head_dim=8, + num_layers=1, + ) @pytest.mark.L0 @@ -383,7 +363,7 @@ def test_decoder_layer_optimized_path_empty_und_tensor_shape(): In the optimized path (frame > 0, KV cache active), the decoder layer creates empty und tensors for all intermediate und variables. These tensors are stored as - ``causal_seq`` in the output FactoredSequencePack, and the *next* decoder layer + ``causal_seq`` in the output SequencePack, and the *next* decoder layer calls ``get_und_seq(input)`` to retrieve them. If they are 1D ``[0]``, a subsequent RMSNorm ``weight [H] * hidden_states [0]`` triggers: RuntimeError: The size of tensor a (H) must match tensor b (0) at non-singleton dim 0 @@ -413,7 +393,7 @@ def test_decoder_layer_optimized_path_empty_und_tensor_shape(): norm_out = weight * new_und # [H] * [0, H] → [0, H] assert norm_out.shape == (0, hidden_dim) - # Verify round-trip through FactoredSequencePack preserves 2D shape. + # Verify round-trip through SequencePack preserves 2D shape. # from_mode_splits(und, gen, meta) stores und as causal_seq; get_und_seq retrieves it. 
meta = {"causal_seq": new_und, "full_only_seq": ref} retrieved = get_und_seq(meta) # type: ignore[arg-type] @@ -421,5 +401,4 @@ def test_decoder_layer_optimized_path_empty_und_tensor_shape(): if __name__ == "__main__": - test_two_way_attention_cmp_flex_attn() test_two_way_attention_vs_three_way_attention() diff --git a/cosmos_framework/model/vfm/mot/context_parallel_test.py b/cosmos_framework/model/vfm/mot/context_parallel_test.py index e1f1fd7..3ba6433 100644 --- a/cosmos_framework/model/vfm/mot/context_parallel_test.py +++ b/cosmos_framework/model/vfm/mot/context_parallel_test.py @@ -11,29 +11,33 @@ import torch.distributed as dist from cosmos_framework.data.vfm.joint_dataloader import IterativeJointDataLoader +from cosmos_framework.model.vfm.mot.attention import ( + SplitInfo, + dispatch_attention, +) +from cosmos_framework.model.vfm.mot.context_parallel_utils import ( + context_parallel_attention, + get_context_parallel_sharded_sequence, +) +from cosmos_framework.model.vfm.mot.parallelize_unified_mot import ARReplicatedIODispatch +from cosmos_framework.model.vfm.mot.unified_mot import _apply_head_sharded_o_proj +from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean from cosmos_framework.data.vfm.sequence_packing import ( - FactoredSequencePack, PackedSequence, build_sequence_plans_from_data_batch, - factored_from_joint_sequence, - from_joint, + pack_input_sequence, +) +from cosmos_framework.data.vfm.sequence_packing.runtime import ( + SequencePack, + from_all_seq, from_mode_splits, get_all_seq, get_gen_seq, get_und_seq, - pack_input_sequence, + sequence_pack_from_packed_sequence, set_gen_seq, set_und_seq, ) -from cosmos_framework.model.vfm.mot.attention import ( - SplitInfo, - dispatch_attention, -) -from cosmos_framework.model.vfm.mot.context_parallel_utils import ( - context_parallel_attention, - get_context_parallel_sharded_sequence, -) -from cosmos_framework.model.vfm.utils.data_and_condition import GenerationDataClean from 
cosmos_framework.utils.vfm.parallelism import ParallelDims @@ -154,7 +158,7 @@ def get_factored_qkv_data( print(f"DEBUG: packed_und_token_indexes length: {packed_und_token_indexes.shape[0]}") print(f"DEBUG: split_lens sum causal: {sum(l for l, m in zip(split_lens, attn_modes) if m == 'causal')}") - global_q_pack = factored_from_joint_sequence( + global_q_pack = sequence_pack_from_packed_sequence( packed_sequence=global_packed_sequence_q, attn_modes=attn_modes, split_lens=split_lens, @@ -162,8 +166,8 @@ def get_factored_qkv_data( packed_und_token_indexes=packed_und_token_indexes, packed_gen_token_indexes=packed_gen_token_indexes, ) - global_k_pack = from_joint(global_packed_sequence_k, global_q_pack) - global_v_pack = from_joint(global_packed_sequence_v, global_q_pack) + global_k_pack = from_all_seq(global_packed_sequence_k, global_q_pack) + global_v_pack = from_all_seq(global_packed_sequence_v, global_q_pack) print(f"DEBUG: global_q_pack causal_seq shape: {get_und_seq(global_q_pack).shape}") return global_q_pack, global_k_pack, global_v_pack @@ -193,8 +197,8 @@ def verify_fwd_output( # COMPARE: On rank 0, concatenate the shards and compare with the baseline if rank == 0: - # cast baseline_output_pack to FactoredSequencePack - baseline_output_pack = cast(FactoredSequencePack, baseline_output_pack) + # cast baseline_output_pack to SequencePack + baseline_output_pack = cast(SequencePack, baseline_output_pack) print(f"Comparing results for world_size={world_size}...") baseline_und_seq = get_und_seq(baseline_output_pack) @@ -376,7 +380,12 @@ def test_context_parallel_attention_two_way(): attention_function_to_wrap = partial(dispatch_attention) print(f"DEBUG: world_size: {world_size}") - parallel_dims = ParallelDims(enable_inference_mode=True, world_size=world_size, dp_shard=1, cp=cp_size) + parallel_dims = ParallelDims( + enable_inference_mode=True, + world_size=world_size, + dp_shard=1, + cp=cp_size, + ) parallel_dims.build_meshes("cuda") cp_mesh = 
parallel_dims.cp_mesh @@ -657,7 +666,7 @@ def simple_packed_test(): head_dim = 128 global_packed_sequence_q, _, _ = create_qkv_sequences(global_packed_data, device, num_heads, num_heads, head_dim) - factored_q_pack = factored_from_joint_sequence( + factored_q_pack = sequence_pack_from_packed_sequence( packed_sequence=global_packed_sequence_q, attn_modes=global_packed_data.attn_modes, split_lens=global_packed_data.split_lens, @@ -709,8 +718,8 @@ def _make_factored_pack( S_gen_global: int, device: torch.device, is_sharded: bool = False, -) -> FactoredSequencePack: - """Minimal single-sample FactoredSequencePack for unit tests. +) -> SequencePack: + """Minimal single-sample SequencePack for unit tests. Metadata always uses GLOBAL (pre-sharding) token counts so the metadata is consistent before and after all-to-all inside context_parallel_attention(). @@ -734,13 +743,15 @@ def _make_factored_pack( } + + @pytest.mark.L0 def test_get_context_parallel_sharded_sequence_three_way(): """get_context_parallel_sharded_sequence() accepts three_way attn_implementation. The causal_8b_480p config uses joint_attn_implementation="three_way" (required by video_temporal_causal=True). The sharding logic is identical to "two_way" — it - operates on the FactoredSequencePack (und/gen split), not on the attention pattern — + operates on the SequencePack (und/gen split), not on the attention pattern — so "three_way" must not be rejected by the assertion. 
Verifies that both und and gen sequences are sharded to 1/world_size tokens per rank, diff --git a/cosmos_framework/model/vfm/mot/context_parallel_utils.py b/cosmos_framework/model/vfm/mot/context_parallel_utils.py index 96bf607..6a6b3dd 100644 --- a/cosmos_framework/model/vfm/mot/context_parallel_utils.py +++ b/cosmos_framework/model/vfm/mot/context_parallel_utils.py @@ -37,11 +37,11 @@ import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor, Replicate, Shard -from torch.nn.attention.flex_attention import BlockMask -from cosmos_framework.data.vfm.sequence_packing import ( - FactoredSequencePack, - JointSequencePack, +from cosmos_framework.model.vfm.mot.attention import SplitInfo +from cosmos_framework.model.vfm.utils.memory import KVToStore, MemoryValue +from cosmos_framework.data.vfm.sequence_packing.runtime import ( + SequencePack, from_mode_splits, get_all_seq, get_causal_seq, @@ -51,8 +51,6 @@ get_und_position_ids, get_und_seq, ) -from cosmos_framework.model.vfm.mot.attention import SplitInfo -from cosmos_framework.model.vfm.utils.memory import KVToStore, MemoryValue from cosmos_framework.utils.vfm.parallelism import ParallelDims @@ -96,10 +94,10 @@ def context_parallel_broadcast_tensor_list( def get_context_parallel_sharded_sequence( attn_implementation: str, - input_pack: FactoredSequencePack, + input_pack: SequencePack, position_ids: torch.Tensor, parallel_dims: ParallelDims | None, -) -> tuple[FactoredSequencePack, torch.Tensor]: +) -> tuple[SequencePack, torch.Tensor]: """ Splits the full input_pack into a local shard for Context Parallelism. 
""" @@ -157,7 +155,7 @@ def get_context_parallel_sharded_sequence( def get_context_parallel_last_hidden_state( - packed_outputs: FactoredSequencePack, + packed_outputs: SequencePack, parallel_dims: ParallelDims | None, ) -> torch.Tensor: if parallel_dims is None or not parallel_dims.cp_enabled: @@ -237,7 +235,7 @@ def gather_seq_scatter_heads( x: shape of [z, seq, h, ...] seq_dim: the dimension to gather head_dim: the dimension to scatter - cp_mesh: ulysses sequence parallelism size + cp_mesh: sequence-sharded context-parallel mesh Returns: torch.Tensor: shape of gathered and scattered tensor """ @@ -260,7 +258,7 @@ def gather_heads_scatter_seq( x (torch.Tensor): shape of [bsz, seq, h/n, ...] head_dim (int): the dimension to gather seq_dim (int): the dimension to scatter - cp_mesh (DeviceMesh): ulysses sequence parallelism size + cp_mesh (DeviceMesh): sequence-sharded context-parallel mesh splits (List[torch.Tensor], optional): Manual splits for variable length scattering Returns: @@ -271,14 +269,14 @@ def gather_heads_scatter_seq( def context_parallel_attention( cp_mesh: DeviceMesh, - packed_query_states: FactoredSequencePack, - packed_key_states: FactoredSequencePack, - packed_value_states: FactoredSequencePack, - attention_mask: BlockMask | SplitInfo, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + attention_mask: SplitInfo, attention_function: Callable, natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, -) -> tuple[FactoredSequencePack | JointSequencePack, KVToStore | None]: +) -> tuple[SequencePack, KVToStore | None]: """Ulysses-style context parallel attention for packed und+gen sequences. 
Each rank holds a sequence shard [S/cp, H, D] for Q and [S/cp, H_kv, D] diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py index 909a1d0..8323a8a 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py @@ -9,7 +9,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from cosmos_framework.data.vfm.sequence_packing import ModalityData, PackedSequence, verify_natten_parameter_list from cosmos_framework.model.vfm.mot.attention import build_packed_sequence from cosmos_framework.model.vfm.mot.context_parallel_utils import ( get_context_parallel_last_hidden_state, @@ -22,6 +21,8 @@ VideoRopePosition3DEmb, ) from cosmos_framework.model.vfm.utils.memory import MemoryState +from cosmos_framework.data.vfm.sequence_packing import ModalityData, PackedSequence +from cosmos_framework.data.vfm.sequence_packing.natten import verify_natten_parameter_list class Cosmos3VFMNetworkConfig(PretrainedConfig): @@ -124,6 +125,7 @@ def __init__(self, language_model, config: Cosmos3VFMNetworkConfig): self.num_kv_heads = text_config.num_key_value_heads self.head_dim = text_config.head_dim self.num_hidden_layers = text_config.num_hidden_layers + self.attention_io_layout = "sequence_sharded" self.predict_text_tokens = config.predict_text_tokens if config.natten_parameter_list is not None and config.joint_attn_implementation != "three_way": @@ -628,13 +630,17 @@ def _encode_vision( assert vision.token_shapes is not None assert isinstance(vision.sequence_indexes, torch.Tensor) assert isinstance(vision.timesteps, torch.Tensor) + torch._assert( + vision.timesteps.dtype in (torch.long, torch.float32), + f"Timestep must be long/float32, got {vision.timesteps.dtype}", + ) + assert isinstance(vision.mse_loss_indexes, torch.Tensor) packed_tokens_vision, original_latent_shapes = 
self.patchify_and_pack_latents( vision.tokens, vision.token_shapes ) # packed_tokens_vision: [total_vision_patches,patch_latent_dim] - - packed_tokens_vision = self.vae2llm(packed_tokens_vision) # [total_vision_patches,hidden_size] + packed_tokens_vision = self.vae2llm(packed_tokens_vision.to(target_dtype)) # [total_vision_patches,hidden_size] # Add absolute position embedding only when NOT using unified 3D mRoPE # (3D mRoPE provides positional information via rotary embeddings instead) @@ -647,7 +653,7 @@ def _encode_vision( has_noisy_vision = vision.mse_loss_indexes.numel() > 0 if has_noisy_vision: - timesteps_vision = vision.timesteps * self.timestep_scale # [N_noisy_frames_vision] + timesteps_vision = vision.timesteps.to(dtype=torch.float32) * self.timestep_scale # [N_noisy_frames_vision] # Timesteps are computed in FP32 for numerical stability. with torch.autocast("cuda", enabled=True, dtype=torch.float32): @@ -1033,12 +1039,28 @@ def forward( # The packer is the single source of truth for the supertoken layout. # ``num_action_tokens_per_supertoken`` is stamped onto ``packed_seq`` by - # ``_pack_supertokens_temporal_causal`` (= tcf when actions are packed + # ``pack_supertokens_temporal_causal`` (= tcf when actions are packed # inline, 0 otherwise) and read unchanged by the attention builder, the # NATTEN metadata generator, and the rolling KV-cache state — keeping # all downstream supertoken geometry automatically in sync with the pack. num_action_tokens_per_supertoken = packed_seq.num_action_tokens_per_supertoken + replicated_attention_io_cp = ( + self.attention_io_layout == "replicated" + and self.parallel_dims is not None + and self.parallel_dims.cp_enabled + ) + # ``sequence_sharded`` attention I/O shards the token sequence, so + # packing must pad sequence lengths to the CP size and the input/output + # sequence helpers need the CP mesh. 
``replicated`` attention I/O keeps + # current-frame sequences replicated and uses the CP mesh later inside + # attention to slice local heads, so the effective sequence-sharding + # world size is 1 here. + sequence_shard_parallel_dims = None if replicated_attention_io_cp else self.parallel_dims + sequence_shard_world_size = ( + 1 if replicated_attention_io_cp else (self.parallel_dims.cp_size if self.parallel_dims else 1) + ) + input_pack, attention_meta, natten_metadata_list = build_packed_sequence( self.config.joint_attn_implementation, packed_sequence=packed_sequence, @@ -1053,7 +1075,7 @@ def forward( num_layers=self.num_hidden_layers, token_shapes=packed_seq.vision.token_shapes, natten_parameter_list=self.natten_parameter_list, - cp_world_size=self.parallel_dims.cp_size if self.parallel_dims else 1, + cp_world_size=sequence_shard_world_size, video_temporal_causal=self.video_temporal_causal, skip_natten_metadata=memory is not None and not memory.requires_natten_metadata(), vision_token_shapes=vision_token_shapes, @@ -1067,7 +1089,7 @@ def forward( attn_implementation=self.config.joint_attn_implementation, input_pack=input_pack, position_ids=packed_seq.position_ids, - parallel_dims=self.parallel_dims, + parallel_dims=sequence_shard_parallel_dims, ) packed_outputs, lbl_metadata = self.language_model( @@ -1079,7 +1101,7 @@ def forward( ) last_hidden_state = get_context_parallel_last_hidden_state( packed_outputs=packed_outputs, - parallel_dims=self.parallel_dims, + parallel_dims=sequence_shard_parallel_dims, ) # [N_total,hidden_size] output_dict = dict() diff --git a/cosmos_framework/model/vfm/mot/modeling_utils.py b/cosmos_framework/model/vfm/mot/modeling_utils.py index 418daec..bcd42d9 100644 --- a/cosmos_framework/model/vfm/mot/modeling_utils.py +++ b/cosmos_framework/model/vfm/mot/modeling_utils.py @@ -367,8 +367,8 @@ def timestep_embedding(t, dim, max_period=10000): half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, 
dtype=torch.float32) / half).to( device=t.device - ) # [D/2] - args = t[:, None].float() * freqs[None] # [N,D/2] + ) + args = t[:, None] * freqs[None] # [N,D/2] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) # [N,D] if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) # [N,D+1] diff --git a/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py b/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py index e72d127..a4601d1 100644 --- a/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py +++ b/cosmos_framework/model/vfm/mot/parallelize_unified_mot.py @@ -1,8 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: OpenMDW-1.1 -from typing import Callable - """FSDP / activation-checkpointing / torch.compile pass for the unified MoT. The activation-checkpointing implementation here mirrors the torchtitan SAC @@ -14,6 +12,7 @@ """ import re +from typing import Callable import torch import torch.nn as nn @@ -21,21 +20,23 @@ checkpoint_wrapper as ptd_checkpoint_wrapper, ) from torch.distributed.fsdp import fully_shard, register_fsdp_forward_method -from torch.nn.attention.flex_attention import BlockMask from torch.utils.checkpoint import ( CheckpointPolicy, create_selective_checkpoint_contexts, ) +from cosmos_framework.utils import log from cosmos_framework.configs.base.defaults.activation_checkpointing import ActivationCheckpointingConfig from cosmos_framework.configs.base.defaults.compile import CompileConfig -from cosmos_framework.data.vfm.sequence_packing import ( - FactoredSequencePack, - JointSequencePack, -) from cosmos_framework.model.vfm.mot.attention import SplitInfo, dispatch_attention from cosmos_framework.model.vfm.mot.context_parallel_utils import context_parallel_attention from cosmos_framework.model.vfm.utils.memory import KVToStore, MemoryValue +from cosmos_framework.data.vfm.sequence_packing.runtime import ( + 
SequencePack, + from_und_gen_splits, + get_gen_seq, + get_und_seq, +) from cosmos_framework.utils.vfm.parallelism import ParallelDims @@ -51,7 +52,7 @@ class ContextParallelDispatch(nn.Module): the inner ``wrapped_dispatch`` with Ulysses-style all-to-all communication. This includes the AR frame 1+ gen-only path — the inner dispatch routes to ``attention_AR_gen_only`` which operates on the - head-sharded tensors produced by the all-to-all. + local-head tensors produced by the all-to-all. All cache writes flow through the ``MemoryState`` interface; neither this class nor the CP attention functions write to the cache directly. @@ -68,13 +69,13 @@ def __init__( def forward( self, - packed_query_states: FactoredSequencePack | JointSequencePack, - packed_key_states: FactoredSequencePack | JointSequencePack, - packed_value_states: FactoredSequencePack | JointSequencePack, - attention_mask: BlockMask | SplitInfo, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + attention_mask: SplitInfo, natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, - ) -> tuple[FactoredSequencePack | JointSequencePack, KVToStore | None]: + ) -> tuple[SequencePack, KVToStore | None]: if memory_value is not None and not memory_value.supports_context_parallel_attention: raise ValueError("Context-parallel doesn't work when training with a KV-cache.") @@ -90,6 +91,127 @@ def forward( ) +class ARReplicatedIODispatch(nn.Module): + """AR CP dispatch for replicated attention I/O with local-head attention. + + ``Replicated I/O`` means the caller-side tensors at the attention boundary + are replicated across CP ranks. It does **not** mean attention compute is + replicated. For AR frame 1+, this wrapper slices the replicated current + Q/K/V to this rank's local Q/KV heads and runs attention against the local + KV-head cache. 
+ + Shape flow for AR frame 1+: + before slicing: + q: [S,H,D], k/v: [S,H_kv,D], cached k/v: [B,S_hist,H_kv/CP,D] + after local head slicing: + q: [S,H/CP,D], k/v: [S,H_kv/CP,D], cached k/v: [B,S_hist,H_kv/CP,D] + after local attention: + out_local: [S,H/CP*D] + after sharded o_proj in PackedAttentionMoT: + out: [S,hidden_size] + + Current-frame hidden states stay replicated. For AR frame 1+, this wrapper + delegates to the existing memory-aware AR attention for local heads, then + returns the local current-frame attention output so ``PackedAttentionMoT`` + can apply the corresponding ``o_proj`` column slice. Frame 0 and non-AR + paths delegate unchanged; frame 0 seeds the local KV-head cache through + ``ARMemoryState.write_for_layer``. + """ + + def __init__( + self, + cp_mesh, + wrapped_dispatch: Callable = dispatch_attention, + ) -> None: + super().__init__() + self.cp_mesh = cp_mesh + self.wrapped_dispatch = wrapped_dispatch + + def _head_slices(self, q_heads: int, kv_heads: int) -> tuple[slice, slice]: + cp_group = self.cp_mesh.get_group() + cp_rank = torch.distributed.get_rank(cp_group) + cp_size = torch.distributed.get_world_size(cp_group) + assert kv_heads % cp_size == 0, ( + f"replicated attention_io_layout requires num_kv_heads({kv_heads}) % cp_size({cp_size}) == 0. " + f"num_kv_heads={kv_heads} is the upper bound for useful local-head attention CP." 
+ ) + assert q_heads % kv_heads == 0, f"Query heads ({q_heads}) must be divisible by KV heads ({kv_heads})" + kv_heads_per_rank = kv_heads // cp_size + q_heads_per_kv_head = q_heads // kv_heads + q_heads_per_rank = kv_heads_per_rank * q_heads_per_kv_head + kv_start = cp_rank * kv_heads_per_rank + kv_end = kv_start + kv_heads_per_rank + q_start = cp_rank * q_heads_per_rank + q_end = q_start + q_heads_per_rank + return slice(q_start, q_end), slice(kv_start, kv_end) + + def _slice_local_heads( + self, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + ) -> tuple[SequencePack, SequencePack, SequencePack]: + # Input heads are full and sequence-replicated on every CP rank: + # q: [S,H,D], k/v: [S,H_kv,D]. + q_und_seq = get_und_seq(packed_query_states) # [S_und,H,D] + q_gen_seq = get_gen_seq(packed_query_states) # [S_curr,H,D] + k_und_seq = get_und_seq(packed_key_states) # [S_und,H_kv,D] + k_gen_seq = get_gen_seq(packed_key_states) # [S_curr,H_kv,D] + v_und_seq = get_und_seq(packed_value_states) # [S_und,H_kv,D] + v_gen_seq = get_gen_seq(packed_value_states) # [S_curr,H_kv,D] + + # Slice the contiguous Q-head group that corresponds to this rank's + # contiguous KV-head group: q -> [S,H/CP,D], k/v -> [S,H_kv/CP,D]. 
+ q_slice, kv_slice = self._head_slices(q_gen_seq.shape[1], k_gen_seq.shape[1]) + q_und_local = q_und_seq[:, q_slice, :].contiguous() # [S_und,H_local,D] + q_gen_local = q_gen_seq[:, q_slice, :].contiguous() # [S_curr,H_local,D] + k_und_local = k_und_seq[:, kv_slice, :].contiguous() # [S_und,H_kv_local,D] + k_gen_local = k_gen_seq[:, kv_slice, :].contiguous() # [S_curr,H_kv_local,D] + v_und_local = v_und_seq[:, kv_slice, :].contiguous() # [S_und,H_kv_local,D] + v_gen_local = v_gen_seq[:, kv_slice, :].contiguous() # [S_curr,H_kv_local,D] + + local_query_pack = from_und_gen_splits(q_und_local, q_gen_local, packed_query_states) + local_key_pack = from_und_gen_splits(k_und_local, k_gen_local, packed_key_states) + local_value_pack = from_und_gen_splits(v_und_local, v_gen_local, packed_value_states) + return local_query_pack, local_key_pack, local_value_pack + + def forward( + self, + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + attention_mask: SplitInfo, + natten_metadata: dict | None = None, + memory_value: MemoryValue | None = None, + ) -> tuple[SequencePack, KVToStore | None]: + if memory_value is None or getattr(memory_value, "frame_idx", 0) <= 0: + return self.wrapped_dispatch( + packed_query_states, + packed_key_states, + packed_value_states, + attention_mask, + natten_metadata=natten_metadata, + memory_value=memory_value, + ) + if getattr(memory_value, "for_cuda_graphs", False): + raise ValueError("replicated attention_io_layout does not support ARMemoryState(for_cuda_graphs=True)") + + local_query_pack, local_key_pack, local_value_pack = self._slice_local_heads( + packed_query_states, + packed_key_states, + packed_value_states, + ) + local_output_pack, kv_to_store = self.wrapped_dispatch( + local_query_pack, + local_key_pack, + local_value_pack, + attention_mask, + natten_metadata=natten_metadata, + memory_value=memory_value, + ) + return local_output_pack, kv_to_store + + def _apply_selective_ac( 
module: nn.Module, ac: ActivationCheckpointingConfig, @@ -222,6 +344,38 @@ def apply_cp( return model +def apply_replicated_attention_io_cp( + model: nn.Module, + parallel_dims: ParallelDims, +) -> nn.Module: + """Install replicated-attention-IO context parallelism on every attention layer.""" + cp_mesh = parallel_dims.cp_mesh + cp_size = parallel_dims.cp_size + first_block = next(iter(model.model.layers.children())) + first_attn = first_block.self_attn + num_kv_heads = int(first_attn.num_key_value_heads) + num_attention_heads = int(first_attn.num_attention_heads) + assert num_kv_heads % cp_size == 0, ( + f"replicated attention_io_layout requires num_kv_heads({num_kv_heads}) % cp_size({cp_size}) == 0. " + f"num_kv_heads={num_kv_heads} is the upper bound for useful local-head attention CP." + ) + log.info( + "[replicated attention I/O CP] enabled " + f"cp_size={cp_size}, num_kv_heads={num_kv_heads}, num_attention_heads={num_attention_heads}, " + f"kv_heads_per_rank={num_kv_heads // cp_size}, max_useful_cp_size={num_kv_heads}", + rank0_only=True, + ) + for _, block in model.model.layers.named_children(): + attn = block.self_attn + attn.replicated_attention_io_local_head_o_proj = True + attn.replicated_attention_io_cp_mesh = cp_mesh + attn.dispatch_attention_fn = ARReplicatedIODispatch( + cp_mesh, + wrapped_dispatch=attn.dispatch_attention_fn, + ) + return model + + def apply_fsdp( model: nn.Module, parallel_dims: ParallelDims, @@ -253,6 +407,7 @@ def parallelize_unified_mot( parallel_dims: ParallelDims | None, compile_config: CompileConfig, ac_config: ActivationCheckpointingConfig, + attention_io_layout: str = "sequence_sharded", ) -> nn.Module: """Optimize the model using CP, FSDP, activation checkpointing, and torch.compile. @@ -271,10 +426,16 @@ def parallelize_unified_mot( back to the dataclass defaults (mode="selective", save the ``save_ops_regex`` ops, mode="full", save only the outputs of each transformer block). 
+ attention_io_layout: Tensor layout at the attention boundary under CP. """ if parallel_dims is not None and parallel_dims.cp_enabled: - apply_cp(model, parallel_dims) + if attention_io_layout == "replicated": + apply_replicated_attention_io_cp(model, parallel_dims) + elif attention_io_layout == "sequence_sharded": + apply_cp(model, parallel_dims) + else: + raise ValueError(f"Unsupported attention_io_layout={attention_io_layout!r}") apply_ac(model, ac_config) if compile_config.enabled: apply_compile(model, compile_config) diff --git a/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py b/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py index 704bbce..746afa0 100644 --- a/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/parallelize_vfm_network.py @@ -47,6 +47,7 @@ def parallelize_vfm_network( parallel_dims: ParallelDims | None, compile_config: CompileConfig, ac_config: ActivationCheckpointingConfig, + attention_io_layout: str = "sequence_sharded", ) -> torch.nn.Module: """Optimize the model using FSDP, CP, activation checkpointing, and torch.compile. @@ -62,7 +63,9 @@ def parallelize_vfm_network( ``OmniMoTModelConfig.sac``. Forwarded to ``parallelize_unified_mot``; ``None`` falls back to the ``ActivationCheckpointingConfig`` defaults. + attention_io_layout: Tensor layout at the attention boundary under CP. 
""" + model.attention_io_layout = attention_io_layout if parallel_dims is not None and parallel_dims.cp_enabled: model.parallel_dims = parallel_dims @@ -71,6 +74,7 @@ def parallelize_vfm_network( parallel_dims=parallel_dims, compile_config=compile_config, ac_config=ac_config, + attention_io_layout=attention_io_layout, ) if compile_config.enabled and compile_config.compiled_region == "all": diff --git a/cosmos_framework/model/vfm/mot/unified_mot.py b/cosmos_framework/model/vfm/mot/unified_mot.py index 2444c66..5c76a15 100644 --- a/cosmos_framework/model/vfm/mot/unified_mot.py +++ b/cosmos_framework/model/vfm/mot/unified_mot.py @@ -10,21 +10,11 @@ import torch from torch import nn +from torch.distributed import ProcessGroup from cosmos_framework.model.attention import attention as imaginaire_attention from cosmos_framework.model.attention.masks import CausalType from cosmos_framework.utils import log -from cosmos_framework.data.vfm.sequence_packing import ( - FactoredSequencePack, - from_joint, - from_und_gen_splits, - get_device_and_dtype, - get_gen_seq, - get_und_seq, - set_gen_seq, - set_und_seq, - zeros_like, -) from cosmos_framework.model.vfm.mot.attention import ( AttentionMaskType, dispatch_attention, @@ -78,6 +68,17 @@ Qwen3VLMoeTextSparseMoeBlock, Qwen3VLMoeVisionModel, ) +from cosmos_framework.data.vfm.sequence_packing.runtime import ( + SequencePack, + from_all_seq, + from_und_gen_splits, + get_device_and_dtype, + get_gen_seq, + get_und_seq, + set_gen_seq, + set_und_seq, + zeros_like, +) # Torch optimization settings torch._dynamo.config.cache_size_limit = 512 @@ -140,10 +141,9 @@ def is_moe(self) -> bool: # MoT wrapper configs — one per architecture family # ----------------------------------------------------------------------------- -# Package root = parent of the top-level ``cosmos_framework`` package directory, -# i.e. the framework repo root in an editable install or site-packages for a -# wheel. 
The shipped model-config JSONs live under ``cosmos_framework/model/...`` -# beneath it. +# Package root = parent of the top-level package directory; the shipped +# model-config JSONs live under it (cosmos_framework/model/... in releases; +# cosmos_framework/model/vfm/... in i4 source). _PACKAGE_ROOT = Path(__file__).resolve().parents[4] @@ -152,8 +152,7 @@ def _resolve_packaged_config_path(json_file: str) -> str: Absolute paths and paths that already exist relative to the CWD are returned unchanged (preserving existing behavior when launched from the repo root). - A relative path that does not exist against the CWD — e.g. the shipped - ``"cosmos_framework/model/.../X.json"`` defaults — is resolved against the + A relative path that does not exist against the CWD is resolved against the installed package root. If that candidate is missing too, the original path is returned so ``open()`` raises the familiar ``FileNotFoundError``. """ @@ -389,14 +388,6 @@ def from_json_file(cls, json_file: str) -> "_MoTConfigBase": fields (when present) are surfaced lazily via :pyattr:`vision_config` and by HF downstream consumers reading the dict directly. - - ``json_file`` may be absolute, or relative. The shipped config - defaults reference these JSONs by a repo-root-relative path (e.g. - ``"cosmos_framework/model/vfm/vlm/qwen3_vl/configs/X.json"``), which - only resolves when the process CWD is the framework repo root. To keep - ``cosmos_framework.scripts.*`` runnable from any working directory, a - relative path that does not exist against the CWD is resolved against - the installed package root. 
""" with open(_resolve_packaged_config_path(json_file), encoding="utf-8") as reader: config_dict = json.load(reader) @@ -463,6 +454,21 @@ def _transform_text_dict(self, text_dict: Mapping[str, Any]) -> Mapping[str, Any # ----------------------------------------------------------------------------- +def _apply_head_sharded_o_proj( + local_attn_output: torch.Tensor, # [N,H_local*D] + projection: nn.Linear, + feature_slice: slice, + cp_group: ProcessGroup, +) -> torch.Tensor: # [N,hidden_size] + """Apply one local input-column slice of ``projection`` and sum partial outputs.""" + local_weight = projection.weight[:, feature_slice] # [hidden_size,H_local*D] + out = torch.nn.functional.linear(local_attn_output, local_weight, bias=None) # [N,hidden_size] + torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=cp_group) + if projection.bias is not None: + out = out + projection.bias # [N,hidden_size] + return out + + class PackedAttentionMoT(nn.Module): """ Dual-pathway packed attention for MoT architectures. 
@@ -534,15 +540,45 @@ def __init__( self._apply_rotary_pos_emb = layer_types.apply_rotary_pos_emb self.dispatch_attention_fn = dispatch_attention + self.replicated_attention_io_local_head_o_proj = False + self.replicated_attention_io_cp_mesh: Any | None = None + + def _replicated_attention_io_q_feature_slice(self) -> slice: + cp_mesh = self.replicated_attention_io_cp_mesh + assert cp_mesh is not None, "replicated attention I/O requires a CP mesh" + cp_group = cp_mesh.get_group() + cp_rank = torch.distributed.get_rank(cp_group) + cp_size = torch.distributed.get_world_size(cp_group) + assert self.num_key_value_heads % cp_size == 0, ( + f"cp_size({cp_size}) must divide num_key_value_heads({self.num_key_value_heads})" + ) + assert self.num_attention_heads % self.num_key_value_heads == 0, ( + f"num_attention_heads({self.num_attention_heads}) must be divisible by " + f"num_key_value_heads({self.num_key_value_heads})" + ) + kv_heads_per_rank = self.num_key_value_heads // cp_size + q_heads_per_kv_head = self.num_attention_heads // self.num_key_value_heads + q_heads_per_rank = kv_heads_per_rank * q_heads_per_kv_head + q_start = cp_rank * q_heads_per_rank + q_end = q_start + q_heads_per_rank + return slice(q_start * self.head_dim, q_end * self.head_dim) + + def _uses_replicated_attention_io_local_head_o_proj(self, memory_value: MemoryValue | None) -> bool: + return ( + self.replicated_attention_io_local_head_o_proj + and memory_value is not None + and getattr(memory_value, "frame_idx", 0) > 0 + and not getattr(memory_value, "for_cuda_graphs", False) + ) def forward( self, - pack: FactoredSequencePack, + pack: SequencePack, attention_mask: AttentionMaskType, - packed_position_embeddings: tuple[FactoredSequencePack, FactoredSequencePack], + packed_position_embeddings: tuple[SequencePack, SequencePack], natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, - ) -> tuple[FactoredSequencePack, KVToStore | None]: + ) -> tuple[SequencePack, KVToStore | None]: 
"""Forward pass with optional memory-augmented attention. When ``memory_value`` is provided, ``dispatch_attention_fn`` routes to @@ -557,7 +593,7 @@ def forward( Args: pack: Packed sequence with und/gen tokens - attention_mask: Attention mask (BlockMask or SplitInfo) + attention_mask: Attention metadata (SplitInfo). packed_position_embeddings: RoPE embeddings (cos, sin) natten_metadata: Optional NATTEN metadata for neighborhood attention. memory_value: Optional read-only tensor container for memory-augmented attention. @@ -619,7 +655,7 @@ def forward( # Produce kv_to_store for MemoryState.write_for_layer() when the # dispatch didn't already provide one (e.g. standard or AR frame-0 - # non-CP paths). CP dispatch returns head-sharded kv_to_store + # non-CP paths). CP dispatch returns local KV-head kv_to_store # directly, so kv_to_store is already non-None in that case. # # Gradient detach is NOT done here; each MemoryState.write_for_layer() @@ -635,9 +671,39 @@ def forward( v_und[:und_len].unsqueeze(0), ) - # Apply projections directly to get final results - und_seq = self.o_proj(get_und_seq(packed_attn_output)) # [N_und,hidden_size] - gen_seq = self.o_proj_moe_gen(get_gen_seq(packed_attn_output)) # [N_gen,hidden_size] + # Attention compute is local-head under both sequence-sharded and + # replicated attention I/O layouts. The difference here is the output + # layout returned to this module. Replicated attention I/O returns only + # this rank's local heads from AR frame 1+ attention: + # gen [N_gen,H_local*D] and und [0,H_local*D]. We therefore apply the + # matching o_proj input-column slice and all-reduce partial outputs back + # to replicated hidden states. The else path receives full attention + # heads at this boundary, so regular o_proj applies: + # und [N_und,H*D] -> [N_und,hidden_size], + # gen [N_gen,H*D] -> [N_gen,hidden_size]. 
+ if self._uses_replicated_attention_io_local_head_o_proj(memory_value): + local_und_attn = get_und_seq(packed_attn_output) # [0,H_local*D] + local_gen_attn = get_gen_seq(packed_attn_output) # [N_gen,H_local*D] + assert local_und_attn.shape[0] == 0, "replicated attention I/O only supports gen-only frame 1+ attention" + feature_slice = self._replicated_attention_io_q_feature_slice() + assert feature_slice.start is not None and feature_slice.stop is not None + expected_local_features = feature_slice.stop - feature_slice.start + assert local_gen_attn.shape[-1] == expected_local_features, ( + f"Expected local attention features {expected_local_features}, got {local_gen_attn.shape[-1]}" + ) + cp_mesh = self.replicated_attention_io_cp_mesh + assert cp_mesh is not None, "replicated attention I/O requires a CP mesh" + cp_group = cp_mesh.get_group() + und_seq = local_gen_attn.new_empty((0, self.hidden_size)) # [0,hidden_size] + gen_seq = _apply_head_sharded_o_proj( + local_gen_attn, + self.o_proj_moe_gen, + feature_slice, + cp_group, + ) # [N_gen,hidden_size] + else: + und_seq = self.o_proj(get_und_seq(packed_attn_output)) # [N_und,hidden_size] + gen_seq = self.o_proj_moe_gen(get_gen_seq(packed_attn_output)) # [N_gen,hidden_size] return from_und_gen_splits(und_seq, gen_seq, pack), kv_to_store # [N_und+N_gen,hidden_size] def reasoner_forward( @@ -781,12 +847,12 @@ def _impl_init_taylorseer(self, cache_dic=None, current=None): def _impl_forward( self, - pack: FactoredSequencePack, + pack: SequencePack, attention_mask, position_ids: torch.Tensor, natten_metadata_list: list | None = None, memory: MemoryState | None = None, -) -> tuple[FactoredSequencePack, dict[str, LBLMetadata]]: +) -> tuple[SequencePack, dict[str, LBLMetadata]]: """Shared training forward pass for the three MoT text models. Used by ``Qwen3VLTextModel``, ``Qwen3VLMoeTextModel``, and @@ -794,7 +860,7 @@ def _impl_forward( Args: pack: Packed sequence with und/gen tokens. 
- attention_mask: Attention mask (BlockMask or SplitInfo). + attention_mask: Attention metadata (SplitInfo). position_ids: Position IDs (1D ``[N]`` for standard RoPE or 2D ``[3, N]`` for mrope). natten_metadata_list: Optional per-layer NATTEN metadata. @@ -815,8 +881,8 @@ def _impl_forward( cos = cos.squeeze(0) # [N,head_dim] sin = sin.squeeze(0) # [N,head_dim] position_embeddings = ( - from_joint(cos, pack), - from_joint(sin, pack), + from_all_seq(cos, pack), + from_all_seq(sin, pack), ) # Tracking the load balancing loss across all layers. For dense models, lbl_metadata_all @@ -950,13 +1016,13 @@ def __init__( def forward( self, - input: FactoredSequencePack, + input: SequencePack, attention_mask, - packed_position_embeddings: tuple[FactoredSequencePack, FactoredSequencePack], + packed_position_embeddings: tuple[SequencePack, SequencePack], natten_metadata: dict | None = None, memory_value: MemoryValue | None = None, gen_only: bool = False, - ) -> tuple[FactoredSequencePack, dict[str, LBLMetadata], KVToStore | None]: + ) -> tuple[SequencePack, dict[str, LBLMetadata], KVToStore | None]: """Forward pass with MoT routing and optional memory-augmented attention. Returns a 3-tuple: ``(hidden_states, lbl_metadata_dict, kv_to_store)``. @@ -1257,7 +1323,7 @@ def reasoner_forward(self, *args, **kwargs) -> torch.Tensor: # The helpers below run *only* the reasoner tower in standard ``[B, T, H]`` # layout with a per-layer KV cache, enabling an efficient prompt-prefill + # token-by-token decode loop for next-token text generation. Sequence -# packing (``FactoredSequencePack``) is intentionally not used here because +# packing (``SequencePack``) is intentionally not used here because # AR text generation has no full-attention generation tokens to pack with. 
# ----------------------------------------------------------------------------- @@ -1944,12 +2010,12 @@ def set_input_embeddings(self, value: nn.Embedding) -> None: def forward( self, - pack: FactoredSequencePack, + pack: SequencePack, attention_mask, position_ids: torch.Tensor, natten_metadata_list: list | None = None, memory: MemoryState | None = None, - ) -> tuple[FactoredSequencePack, dict[str, LBLMetadata]]: + ) -> tuple[SequencePack, dict[str, LBLMetadata]]: """Training forward pass — delegates to the dense text model.""" outputs = self.model( pack=pack, @@ -2075,12 +2141,12 @@ def set_input_embeddings(self, value: nn.Embedding) -> None: def forward( self, - pack: FactoredSequencePack, + pack: SequencePack, attention_mask, position_ids: torch.Tensor, natten_metadata_list: list | None = None, memory: MemoryState | None = None, - ) -> tuple[FactoredSequencePack, dict[str, LBLMetadata]]: + ) -> tuple[SequencePack, dict[str, LBLMetadata]]: """Training forward pass — delegates to the MoE text model.""" outputs = self.model( @@ -2203,12 +2269,12 @@ def set_input_embeddings(self, value: nn.Embedding) -> None: def forward( self, - pack: FactoredSequencePack, + pack: SequencePack, attention_mask, position_ids: torch.Tensor, natten_metadata_list: list | None = None, memory: MemoryState | None = None, - ) -> tuple[FactoredSequencePack, dict[str, LBLMetadata]]: + ) -> tuple[SequencePack, dict[str, LBLMetadata]]: return self.model( pack=pack, attention_mask=attention_mask, diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index 652b930..f725a18 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -27,13 +27,6 @@ from cosmos_framework.model.vfm.algorithm.loss.load_balancing import compute_load_balancing_loss from cosmos_framework.configs.base.defaults.model_config import OmniMoTModelConfig from cosmos_framework.data.vfm.action.action_processing import 
ActionProcessor, get_action_processing_records -from cosmos_framework.data.vfm.sequence_packing import ( - PackedSequence, - SequencePlan, - add_special_tokens, - build_sequence_plans_from_data_batch, - pack_input_sequence, -) from cosmos_framework.data.vfm.utils import IMAGE_RES_SIZE_INFO, VIDEO_RES_SIZE_INFO from cosmos_framework.model.vfm.diffusion.rectified_flow import RectifiedFlow from cosmos_framework.model.vfm.diffusion.samplers.edm import EDMSampler @@ -53,6 +46,13 @@ from cosmos_framework.model.vfm.utils.memory import MemoryState from cosmos_framework.model.vfm.utils.safetensors_loader import load_language_model as load_language_model_safetensors from cosmos_framework.model.vfm.vlm.qwen3_vl.utils import tokenize_caption +from cosmos_framework.data.vfm.sequence_packing import ( + PackedSequence, + SequencePlan, + build_sequence_plans_from_data_batch, + pack_input_sequence, +) +from cosmos_framework.data.vfm.sequence_packing.modalities import add_special_tokens from cosmos_framework.model.vfm.tokenizers.interface import VideoTokenizerInterface from cosmos_framework.model.vfm.upsampler.prompts import build_messages, clean_response from cosmos_framework.utils.vfm.data_utils import get_vision_data_resolution @@ -237,6 +237,7 @@ def build_net(self, dtype: torch.dtype): parallel_dims=self.parallel_dims, compile_config=self.config.compile, ac_config=self.config.activation_checkpointing, + attention_io_layout=self.config.parallelism.attention_io_layout, ) with misc.timer("meta to cuda and broadcast model states"): @@ -509,10 +510,10 @@ def init_optimizer_scheduler( def _derive_include_end_of_generation_token(self) -> bool: impl = self.config.joint_attn_implementation - assert impl in ("flex", "two_way", "three_way"), ( - f"Invalid joint_attn_implementation: {impl}. Must be 'flex', 'two_way', or 'three_way'." + assert impl in ("two_way", "three_way"), ( + f"Invalid joint_attn_implementation: {impl}. Must be 'two_way' or 'three_way'." 
) - return impl == "flex" + return False # ------------------------ training hooks ------------------------ def on_before_zero_grad( @@ -761,8 +762,8 @@ def training_step( ) # [B, T_vis] each # Optional independent action schedule (sampled from rectified_flow_action with - # action-specific shift/high-sigma overrides). Only active when the config opts in and - # the batch contains action data. + # action-specific shift). Only active when the config opts in and the batch contains + # action data. # # Mixed-batch indexing: gen_data_clean.x0_tokens_action (and every packed_sequence.action.* # field) is *dense* — one entry per sample with has_action=True, in the original batch order @@ -975,7 +976,6 @@ def _compute_flow_matching_loss( timesteps: torch.Tensor, has_valid_tokens: bool, rectified_flow: RectifiedFlow, - loss_scale: float | None = None, raw_action_dim: list[torch.Tensor] | None = None, normalize_by_active: bool = False, ) -> torch.Tensor: @@ -993,8 +993,6 @@ def _compute_flow_matching_loss( are handled correctly. has_valid_tokens: Whether this modality has valid noisy tokens. rectified_flow: The rectified flow object for time weighting. - loss_scale: Optional per-modality loss scale. Falls back to the global - ``rectified_flow_training_config.loss_scale`` when *None*. normalize_by_active: When True, normalize per-instance loss by the count of active (noisy) elements rather than all elements. 
Preserves the ``sum / active_count`` semantics needed for distillation critics where @@ -1014,7 +1012,6 @@ def _compute_flow_matching_loss( has_valid_tokens=has_valid_tokens, rectified_flow=rectified_flow, tensor_kwargs_fp32=self.tensor_kwargs_fp32, - loss_scale=loss_scale, raw_action_dim=raw_action_dim, normalize_by_active=normalize_by_active, ) @@ -1198,8 +1195,8 @@ def _get_train_noise_level_vision( iteration: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ - Sample the rectified flow interpolation coefficient (timesteps), optionally adjust the sampled - timesteps with high sigma strategy, and obtain the corresponding normalized timestep. + Sample the rectified flow interpolation coefficient (timesteps) and obtain the corresponding + normalized timestep. Args: batch_size: Batch size for sampling timesteps. @@ -1271,74 +1268,38 @@ def _get_train_noise_level_vision( # T_max = max(num_vision_latent_frames) across the batch; trailing entries for shorter # sequences are unused (sliced away in _add_noise_to_input). 
T_max = max(num_vision_latent_frames) - t_raw = ( - rectified_flow.sample_train_time(batch_size * T_max, iteration=iteration) + sigmas = ( + rectified_flow.sample_train_time( + batch_size * T_max, iteration=iteration, shifts=shifts.repeat_interleave(T_max) + ) .to(**self.tensor_kwargs_fp32) .reshape(batch_size, T_max) ) # [B,T_max] else: - t_raw = ( - rectified_flow.sample_train_time(batch_size, iteration=iteration) + sigmas = ( + rectified_flow.sample_train_time(batch_size, iteration=iteration, shifts=shifts) .to(**self.tensor_kwargs_fp32) .unsqueeze(1) ) # [B,1] - # Apply shift and scale: t_raw ∈ [0,1] → timesteps ∈ [0,max_timestep] - # shifts.unsqueeze(1) → [B,1], broadcasts with both [B,1] (base/TF) and [B,T_max] (DF) - t = 1 - t_raw # [B,1] or [B,T_max] - shifts_2d = shifts.unsqueeze(1).to(t_raw.device) # [B,1], broadcasts with [B,1] and [B,T_max] - timesteps = shifts_2d * t / (1 + (shifts_2d - 1) * t) * max_timestep # [B,1] or [B,T_max] - - if self.config.rectified_flow_training_config.use_high_sigma_strategy: - timesteps = self._apply_high_noise_strategy(timesteps, max_timestep) # [B,1] or [B,T_max] - - sigmas = timesteps / max_timestep # [B,1] for base/TF, [B,T_max] for DF + timesteps = sigmas * max_timestep # [B,1] or [B,T_max] return timesteps, sigmas - def _apply_high_noise_strategy(self, timesteps: torch.Tensor, max_timestep: int) -> torch.Tensor: - """ - Update the sampled RF timesteps to shift the distribution towards higher noise levels (high sigmas). - - Args: - timesteps (torch.Tensor): Input timesteps. Shape [B,1] for base/TF or [B,T_max] for DF. - max_timestep (int): The maximum timestep value. - - Returns: - torch.Tensor: Timesteps with the same shape as input — [B,1] or [B,T_max]. 
- """ - mask = ( - torch.rand(timesteps.shape, device=timesteps.device) - < self.config.rectified_flow_training_config.high_sigma_ratio - ) - new_timesteps = ( - torch.rand(timesteps.shape, device=timesteps.device).type_as(timesteps) - * ( - self.config.rectified_flow_training_config.high_sigma_timesteps_max - - self.config.rectified_flow_training_config.high_sigma_timesteps_min - ) - + self.config.rectified_flow_training_config.high_sigma_timesteps_min - ) - timesteps = torch.where(mask, new_timesteps, timesteps) - - return timesteps - def _get_train_noise_level_action( self, batch_size: int, iteration: int | None = None ) -> tuple[torch.Tensor, torch.Tensor]: """Sample ``(timesteps, sigmas)`` of shape ``[batch_size, 1]`` from ``rectified_flow_action``. This helper is locally-scoped: it just draws ``batch_size`` independent σ values and - applies action-specific shift / high-sigma config. The caller decides what ``batch_size`` - means semantically — ``training_step`` passes the full batch size and then reindexes to + applies action-specific shift config. The caller decides what ``batch_size`` means + semantically — ``training_step`` passes the full batch size and then reindexes to the dense action-bearing subset with ``action_sample_indices``. ``shift_action`` must be an int (or ``None`` to inherit ``shift``). Dict-keyed per-resolution shifts are vision-only — multi-resolution action training would need per-sample lookup, which this helper does not implement; if the global ``shift`` is a dict and ``shift_action`` is None, this raises so the user sets shift_action explicitly. - ``use_high_sigma_strategy_action`` toggles the high-σ strategy for action; when on, the - global ``high_sigma_ratio`` / ``_min`` / ``_max`` apply. σ is a shared scalar per input - slot (no per-frame σ for action). + σ is a shared scalar per input slot (no per-frame σ for action). 
""" rf_cfg = self.config.rectified_flow_training_config rf = self.rectified_flow_action @@ -1361,17 +1322,13 @@ def _get_train_noise_level_action( f"Got shift={rf_cfg.shift!r}." ) - t_raw = ( - rf.sample_train_time(batch_size, iteration=iteration).to(**self.tensor_kwargs_fp32).unsqueeze(1) + shifts = torch.full((batch_size,), shift_val, dtype=torch.float32) + sigmas = ( + rf.sample_train_time(batch_size, iteration=iteration, shifts=shifts) + .to(**self.tensor_kwargs_fp32) + .unsqueeze(1) ) # [B,1] - t = 1 - t_raw # [B,1] - shifts_2d = torch.full((batch_size, 1), shift_val, dtype=torch.float32, device=t_raw.device) # [B,1] - timesteps = shifts_2d * t / (1 + (shifts_2d - 1) * t) * max_timestep # [B,1] - - if rf_cfg.use_high_sigma_strategy_action: - timesteps = self._apply_high_noise_strategy(timesteps, max_timestep) # [B,1] - - sigmas = timesteps / max_timestep # [B,1] + timesteps = sigmas * max_timestep # [B,1] return timesteps, sigmas def _get_train_noise_level_sound(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: @@ -1401,15 +1358,9 @@ def _get_train_noise_level_sound(self, batch_size: int) -> tuple[torch.Tensor, t f"Got shift={rf_cfg.shift!r}." 
) - t_raw = rf.sample_train_time(batch_size).to(**self.tensor_kwargs_fp32).unsqueeze(1) # [B,1] - t = 1 - t_raw # [B,1] - shifts_2d = torch.full((batch_size, 1), shift_val, dtype=torch.float32, device=t_raw.device) # [B,1] - timesteps = shifts_2d * t / (1 + (shifts_2d - 1) * t) * max_timestep # [B,1] - - if rf_cfg.use_high_sigma_strategy_sound: - timesteps = self._apply_high_noise_strategy(timesteps, max_timestep) # [B,1] - - sigmas = timesteps / max_timestep # [B,1] + shifts = torch.full((batch_size,), shift_val, dtype=torch.float32) + sigmas = rf.sample_train_time(batch_size, shifts=shifts).to(**self.tensor_kwargs_fp32).unsqueeze(1) # [B,1] + timesteps = sigmas * max_timestep # [B,1] return timesteps, sigmas def _add_noise_to_input( @@ -1759,18 +1710,21 @@ def _prepare_inference_data( seed_dict["action"].append(seed[sample_idx]) seed_dict["sound"].append(seed[sample_idx]) - # Generate noise and apply conditioning per vision item (supports variable shapes) + # Generate noise and apply conditioning per vision item (supports variable shapes). + # Noise and the conditioning blend are kept in fp32 so the sampler accumulates + # in full precision. x0_token is already fp32 (forced by .float() in + # get_data_and_condition). The cast to model dtype happens inside velocity_fn. noise_vision_list: list[torch.Tensor] = [] for i, (x0_token, cond_mask) in enumerate( zip(gen_data_clean.x0_tokens_vision, packed_sequence.vision.condition_mask, strict=True) ): pure_noise_i = misc.arch_invariant_rand( tuple(x0_token.shape), - self.tensor_kwargs["dtype"], - self.tensor_kwargs["device"], + self.tensor_kwargs_fp32["dtype"], + self.tensor_kwargs_fp32["device"], seed_dict["vision"][i], # Different seed per sample for diversity ) # [C,T,H,W] - noise_i = cond_mask * x0_token.to(**self.tensor_kwargs) + (1.0 - cond_mask) * pure_noise_i # [C,T,H,W] + noise_i = cond_mask * x0_token + (1.0 - cond_mask) * pure_noise_i # [C,T,H,W] noise_vision_list.append(noise_i) # 6. 
Initialize action noise if action_gen is True @@ -2209,6 +2163,11 @@ def generate_samples_from_batch( sampler: Any | None = None, guidance: float = 1.5, guidance_interval: Optional[list[float]] = None, + velocity_postprocess_builder: Optional[ + Callable[ + ..., Optional[Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]]] + ] + ] = None, seed: list[int] | int = 1, n_sample: int | None = None, has_negative_prompt: bool = False, @@ -2362,6 +2321,24 @@ def generate_samples_from_batch( assert n_sample == len(seed), f"Number of samples {n_sample} must match number of seeds {len(seed)}" + # Optional per-step velocity postprocess hook. Built once via a builder + # that receives the prepared inference state. The returned callable (if + # any) is invoked after the conditional forward on every step and can + # modify the conditional velocity (e.g. inject control-CFG, attention + # weighting, etc.). The model itself stays agnostic of what the hook + # does — all transfer/edit-specific logic lives in the caller. + velocity_postprocess: Optional[ + Callable[[list[torch.Tensor], list[torch.Tensor], torch.Tensor], list[torch.Tensor]] + ] = None + if velocity_postprocess_builder is not None: + velocity_postprocess = velocity_postprocess_builder( + model=self, + net=net, + cond_tokens=cond_tokens, + sequence_plans=sequence_plans, + gen_data_clean=gen_data_clean, + ) + # Create a velocity function for a single sample (for use with self.sampler). # FSDP collective-sequence alignment (throughput-preset inference). 
# @@ -2422,34 +2399,64 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): skip_text_tokens=skip_text_tokens, ) - # Skip unconditional branch when outside the guidance interval - needs_cfg = guidance != 1.0 - if needs_cfg and guidance_interval is not None: + needs_text_cfg = guidance != 1.0 + if needs_text_cfg and guidance_interval is not None: assert len(guidance_interval) == 2, f"guidance_interval must be [lo, hi], got {guidance_interval}" t_lo, t_hi = guidance_interval - needs_cfg = t_lo < timestep[0].item() < t_hi + needs_text_cfg = t_lo < timestep[0].item() < t_hi - # FSDP alignment: if ANY rank in the shard group needs CFG this - # call, every rank computes both forwards. Cheap 1-element - # all_reduce per velocity_fn call; the alternative (forcing CFG - # always-on globally) would silently ignore the per-timestep - # ``guidance_interval`` gate. + # FSDP alignment: if ANY rank in the shard group needs text-CFG + # this call, every rank must take the CFG path so the allgather + # sequence stays aligned across ranks. if _dp_shard_group is not None: - _cfg_t = torch.tensor([1 if needs_cfg else 0], device=_align_device, dtype=torch.int32) + _cfg_t = torch.tensor([1 if needs_text_cfg else 0], device=_align_device, dtype=torch.int32) torch.distributed.all_reduce(_cfg_t, op=torch.distributed.ReduceOp.MAX, group=_dp_shard_group) - _any_needs_cfg = bool(_cfg_t.item()) + _any_needs_text_cfg = bool(_cfg_t.item()) else: - _any_needs_cfg = needs_cfg + _any_needs_text_cfg = needs_text_cfg - if not _any_needs_cfg: + # Fast path: no text-CFG anywhere and no postprocess hook — single forward. 
+ if not _any_needs_text_cfg and velocity_postprocess is None: return _single_velocity_fn(cond_tokens, skip_text_tokens=False) - cond_v, uncond_v = self._run_classifier_free_guidance( - cond_tokens=cond_tokens, - uncond_tokens=uncond_tokens, - skip_text_tokens_for_cfg=skip_text_tokens_for_cfg, - single_velocity_fn=_single_velocity_fn, - ) + # Fast path: only text-CFG and no postprocess — preserve the + # cfgp-parallel branch so two-rank CFG parallelism stays available. + if velocity_postprocess is None: + cond_v, uncond_v = self._run_classifier_free_guidance( + cond_tokens=cond_tokens, + uncond_tokens=uncond_tokens, + skip_text_tokens_for_cfg=skip_text_tokens_for_cfg, + single_velocity_fn=_single_velocity_fn, + ) + if not needs_text_cfg: + # Peers needed CFG so we ran the uncond forward to keep + # FSDP allgather aligned; locally we still return cond. + return cond_v + v_pred = [u_i + guidance * (c_i - u_i) for c_i, u_i in zip(cond_v, uncond_v)] + if normalize_cfg: + v_pred = [ + v_i * (torch.norm(c_i) / (torch.norm(v_i) + 1e-8)).clamp(min=0.0, max=1.0) + for v_i, c_i in zip(v_pred, cond_v) + ] + return v_pred + + # Conditional forward, then per-step postprocess hook. Hook runs + # sequentially; cfgp parallelism not used on this path. + cond_v_full = _single_velocity_fn(cond_tokens, skip_text_tokens=False) + cond_v = velocity_postprocess(cond_v_full, noise_x, timestep) + + uncond_v = _single_velocity_fn(uncond_tokens, skip_text_tokens=skip_text_tokens_for_cfg) + if not needs_text_cfg: + # Same alignment story as above for the postprocess branch. + return cond_v + + if not needs_cfg: + # This rank doesn't actually need CFG (guidance==1.0 or sigma + # outside guidance_interval). Return cond_v directly so the + # output is bit-identical to the original no-CFG path; the + # uncond_v forward was only run to keep the FSDP allgather + # sequence aligned with peers. 
+ return cond_v if not needs_cfg: # This rank doesn't actually need CFG (guidance==1.0 or sigma @@ -2485,12 +2492,14 @@ def _single_velocity_fn(tokens: list[list[int]], skip_text_tokens: bool): # Run sampler for all samples at once. sampler = sampler or self.sampler scheduler_type = self.config.rectified_flow_inference_config.scheduler_type - if scheduler_type == "unipc": + if isinstance(sampler, FixedStepSampler): + log.info(f"Using sampler: FixedStep (t_list={sampler.t_list}, sample_type={sampler.sample_type})") + elif scheduler_type == "unipc": log.info(f"Using sampler: UniPC (shift={shift}, num_steps={num_steps})") else: log.info(f"Using sampler: EDM (sigma_max={sigma_max}, num_steps={num_steps})") - if scheduler_type == "unipc": + if isinstance(sampler, FixedStepSampler) or scheduler_type == "unipc": latents = sampler( velocity_fn, initial_noise, @@ -2979,7 +2988,7 @@ def _normalize_video_databatch_inplace( if isinstance(item, torch.Tensor): item = [item] assert item[0].dtype == torch.uint8, "Video data is not in uint8 format." - data_batch[input_key][i] = torch.stack(item).to(**self.tensor_kwargs) / 127.5 - 1.0 + data_batch[input_key][i] = torch.stack(item).to(**self.tensor_kwargs_fp32) / 127.5 - 1.0 data_batch[IS_PREPROCESSED_KEY] = True def _normalize_action_databatch( @@ -3119,6 +3128,10 @@ def _augment_image_dim_inplace(self, data_batch: dict[str, torch.Tensor], input_ assert data_batch[input_key][i].shape[2] == 1, ( f"Image data is claimed be augmented while its shape is {data_batch[input_key][i].shape} for sample {i}" ) + assert torch.is_floating_point(data_batch[input_key][i]), "Image data is not in float format." + assert torch.all((data_batch[input_key][i] >= -1.0001) & (data_batch[input_key][i] <= 1.0001)), ( + f"Image data is not in the range [-1, 1]. 
get data range [{data_batch[input_key][i].min()}, {data_batch[input_key][i].max()}]" + ) return else: new_image_tensor_list = [] @@ -3126,7 +3139,7 @@ def _augment_image_dim_inplace(self, data_batch: dict[str, torch.Tensor], input_ for img_tensor in data_batch[input_key][i]: img_tensor = rearrange(img_tensor, "c h w -> 1 c 1 h w").contiguous() if img_tensor.dtype == torch.uint8: - img_tensor = img_tensor.to(**self.tensor_kwargs) / 127.5 - 1.0 + img_tensor = img_tensor.to(**self.tensor_kwargs_fp32) / 127.5 - 1.0 new_image_tensor_list.append(img_tensor) data_batch[input_key] = new_image_tensor_list data_batch[IS_PREPROCESSED_KEY] = True diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py index 18b242a..0a1f76a 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_4x32x32.py @@ -160,7 +160,7 @@ def decode(self, latent: torch.Tensor) -> torch.Tensor: def get_latent_num_frames(self, num_pixel_frames: int) -> int: return (num_pixel_frames + self.model.cfg.num_pad_frames) // self._temporal_compression_factor - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return num_latent_frames * self._temporal_compression_factor - self.model.cfg.num_pad_frames @property diff --git a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py index 71daf06..976f702 100644 --- a/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py +++ b/cosmos_framework/model/vfm/tokenizers/dc_ae/dc_ae_v.py @@ -795,13 +795,19 @@ def temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: if feat_idx is not None: feat_idx[0] = 0 if feature_cache is not None and self.cfg.compilable: + old_feature_cache = feature_cache feature_cache = [f.clone() if f is not None else None for f in feature_cache] + if 
old_feature_cache is not None: + old_feature_cache.clear() tile = self.encoder(tile, feature_cache=feature_cache, feat_idx=feat_idx)[0].clone() if remove_padding: valid_latent_t = (actual_t + compression_factor - 1) // compression_factor tile = tile[:, :, :valid_latent_t, :, :] row.append(tile) + if feature_cache is not None: + feature_cache.clear() + result_row = [] for i, tile in enumerate(row): if i > 0: @@ -901,14 +907,14 @@ def dc_ae_v_f32t4_encoder_causal_decoder_chunk_causal_4( latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 64, 7, 1, 0.5704 encoder_width_list = [0, 64, 128, 512, 1024, 1024, 1024] elif name in [ - "dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", + "dcae4x32x32_c96_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr", ]: - latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 96, 7, 1, 0.5185 + latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 96, 7, 1, 0.4766 encoder_width_list = [0, 64, 128, 512, 1024, 1024, 1024] elif name in [ - "dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2", + "dcae4x32x32_c128_t120_256p_fps_all_encoder_causal_decoder_chunk_causal_4_nogan_cosmos_pad_7_v0.2_lcr", ]: - latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 128, 7, 1, 0.5209 + latent_channels, num_pad_frames, temporal_remainder, scaling_factor = 128, 7, 1, 0.5637 encoder_width_list = [0, 64, 128, 512, 1024, 1024, 1024] else: raise ValueError(f"model {name} is not supported") diff --git a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py index 3446646..85188b4 100644 --- a/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/flux_vae_8x8.py @@ -450,7 +450,7 @@ def get_latent_num_frames(self, num_pixel_frames: int) -> int: """Get number of latent frames 
from pixel frames.""" return num_pixel_frames # Flux VAE doesn't compress temporally - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: """Get number of pixel frames from latent frames.""" return num_latent_frames # Flux VAE doesn't compress temporally diff --git a/cosmos_framework/model/vfm/tokenizers/interface.py b/cosmos_framework/model/vfm/tokenizers/interface.py index 3c023bc..d307626 100644 --- a/cosmos_framework/model/vfm/tokenizers/interface.py +++ b/cosmos_framework/model/vfm/tokenizers/interface.py @@ -49,7 +49,7 @@ def get_latent_num_frames(self, num_pixel_frames: int) -> int: pass @abstractmethod - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: pass def get_latent_temporal_positions( diff --git a/cosmos_framework/model/vfm/tokenizers/stable_diffusion_vae_8x8.py b/cosmos_framework/model/vfm/tokenizers/stable_diffusion_vae_8x8.py new file mode 100644 index 0000000..2284230 --- /dev/null +++ b/cosmos_framework/model/vfm/tokenizers/stable_diffusion_vae_8x8.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Stable Diffusion VAE tokenizer wrapper for DiT image pretraining.""" + +from typing import Any + +import torch + +from cosmos_framework.model.vfm.tokenizers.interface import VideoTokenizerInterface + +_DTYPE_BY_NAME: dict[str, torch.dtype] = { + "float32": torch.float32, + "fp32": torch.float32, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, +} + + +def _resolve_dtype(dtype: str | torch.dtype) -> torch.dtype: + """Resolve a string dtype name to a torch dtype.""" + if isinstance(dtype, torch.dtype): + return dtype + if dtype not in _DTYPE_BY_NAME: + raise ValueError(f"Unsupported SD VAE dtype '{dtype}'. 
Supported values: {sorted(_DTYPE_BY_NAME)}.") + return _DTYPE_BY_NAME[dtype] + + +def _default_device() -> torch.device: + """Return the current CUDA device when available.""" + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") + + +def _config_value(config: Any, key: str, default: Any) -> Any: + """Fetch a diffusers config value from object-like or dict-like configs.""" + if hasattr(config, key): + return getattr(config, key) + if isinstance(config, dict): + return config.get(key, default) + return default + + +class StableDiffusionVAEInterface(VideoTokenizerInterface): + """Stable Diffusion AutoencoderKL adapter using the shared video tokenizer interface.""" + + def __init__( + self, + bucket_name: str = "", + object_store_credential_path_pretrained: str | None = None, + vae_path: str = "stabilityai/sd-vae-ft-ema", + subfolder: str | None = None, + scaling_factor: float | None = 0.18215, + sample_posterior: bool = True, + dtype: str | torch.dtype = "float32", + device: str | None = None, + chunk_duration: int = 1, + spatial_compression_factor: int = 8, + temporal_compression_factor: int = 1, + ) -> None: + super().__init__(object_store_credential_path_pretrained=object_store_credential_path_pretrained) + self.vae_path = vae_path + self.subfolder = subfolder + self.sample_posterior = sample_posterior + self._dtype = _resolve_dtype(dtype) + self.device = torch.device(device) if device is not None else _default_device() + self.chunk_duration = chunk_duration + self.spatial_compression = spatial_compression_factor + self.temporal_compression = temporal_compression_factor + + resolved_vae_path = self._resolve_vae_path(bucket_name=bucket_name, vae_path=vae_path) + self.model = self._load_model(vae_path=resolved_vae_path, subfolder=subfolder) + self.model.eval() + self.model.requires_grad_(False) + self.model.to(device=self.device, dtype=self._dtype) + + model_config = getattr(self.model, "config", 
None) + self.scaling_factor = float( + scaling_factor if scaling_factor is not None else _config_value(model_config, "scaling_factor", 0.18215) + ) + self._latent_ch = int(_config_value(model_config, "latent_channels", 4)) + self._spatial_resolution = int(_config_value(model_config, "sample_size", 256)) + + def _resolve_vae_path(self, bucket_name: str, vae_path: str) -> str: + """Resolve internal pretrained paths while leaving Hugging Face repo ids unchanged.""" + if vae_path.startswith("pretrained/") and bucket_name: + return f"s3://{bucket_name}/{vae_path}" + return vae_path + + def _load_model(self, vae_path: str, subfolder: str | None) -> torch.nn.Module: + """Load a diffusers AutoencoderKL model.""" + try: + from diffusers import AutoencoderKL + except ImportError as error: + raise ImportError( + "StableDiffusionVAEInterface requires diffusers. Install diffusers or use wan2pt1_tokenizer." + ) from error + + kwargs: dict[str, object] = {"torch_dtype": self._dtype} + if subfolder is not None: + kwargs["subfolder"] = subfolder + if vae_path.startswith("s3://"): + raise ValueError( + "StableDiffusionVAEInterface expects a Hugging Face repo id or local diffusers VAE directory, " + f"not an S3 path: {vae_path}" + ) + return AutoencoderKL.from_pretrained(vae_path, **kwargs) + + @property + def dtype(self) -> torch.dtype: + """Model compute dtype.""" + return self._dtype + + def reset_dtype(self) -> None: + """Reset the dtype of the model.""" + self.model.to(device=self.device, dtype=self._dtype) + + @torch.no_grad() + def encode(self, state: torch.Tensor) -> torch.Tensor: + """Encode normalized pixels in [-1, 1] to scaled SD VAE latents.""" + if state.dim() != 5: + raise ValueError(f"Expected state tensor [B,3,T,H,W], got shape {tuple(state.shape)}.") + batch_size, channels, num_frames, height, width = state.shape + if channels != 3: + raise ValueError(f"Expected 3 input channels, got {channels}.") + + frames = state.permute(0, 2, 1, 3, 4).contiguous() # [B,T,3,H,W] 
+ frames = frames.view(batch_size * num_frames, channels, height, width) # [B*T,3,H,W] + frames = frames.to(device=self.device, dtype=self._dtype) # [B*T,3,H,W] + + posterior = self.model.encode(frames).latent_dist + if self.sample_posterior: + latents_2d = posterior.sample() # [B*T,4,H//8,W//8] + else: + latents_2d = posterior.mode() # [B*T,4,H//8,W//8] + latents_2d = latents_2d * self.scaling_factor # [B*T,4,H//8,W//8] + + latent_channels = latents_2d.shape[1] + latent_height = latents_2d.shape[2] + latent_width = latents_2d.shape[3] + latents = latents_2d.view( + batch_size, num_frames, latent_channels, latent_height, latent_width + ) # [B,T,4,H//8,W//8] + latents = latents.permute(0, 2, 1, 3, 4).contiguous() # [B,4,T,H//8,W//8] + return latents + + @torch.no_grad() + def decode(self, latent: torch.Tensor) -> torch.Tensor: + """Decode scaled SD VAE latents to normalized pixels in [-1, 1].""" + if latent.dim() != 5: + raise ValueError(f"Expected latent tensor [B,4,T,H,W], got shape {tuple(latent.shape)}.") + batch_size, channels, num_frames, height, width = latent.shape + if channels != self._latent_ch: + raise ValueError(f"Expected {self._latent_ch} latent channels, got {channels}.") + + latents_2d = latent.permute(0, 2, 1, 3, 4).contiguous() # [B,T,4,H,W] + latents_2d = latents_2d.view(batch_size * num_frames, channels, height, width) # [B*T,4,H,W] + latents_2d = latents_2d.to(device=self.device, dtype=self._dtype) / self.scaling_factor # [B*T,4,H,W] + + decoded_2d = self.model.decode(latents_2d).sample # [B*T,3,H*8,W*8] + decoded_height = decoded_2d.shape[2] + decoded_width = decoded_2d.shape[3] + decoded = decoded_2d.view(batch_size, num_frames, 3, decoded_height, decoded_width) # [B,T,3,H*8,W*8] + decoded = decoded.permute(0, 2, 1, 3, 4).contiguous() # [B,3,T,H*8,W*8] + return decoded + + def get_latent_num_frames(self, num_pixel_frames: int) -> int: + """Get number of latent frames from pixel frames.""" + return num_pixel_frames + + def 
get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: + """Get number of pixel frames from latent frames.""" + return num_latent_frames + + @property + def spatial_compression_factor(self) -> int: + """Spatial compression factor.""" + return self.spatial_compression + + @property + def temporal_compression_factor(self) -> int: + """Temporal compression factor.""" + return self.temporal_compression + + @property + def spatial_resolution(self) -> int: + """Spatial resolution.""" + return self._spatial_resolution + + @property + def pixel_chunk_duration(self) -> int: + """Pixel chunk duration.""" + return self.chunk_duration + + @property + def latent_chunk_duration(self) -> int: + """Latent chunk duration.""" + return self.get_latent_num_frames(self.chunk_duration) + + @property + def latent_ch(self) -> int: + """Number of latent channels.""" + return self._latent_ch + + @property + def name(self) -> str: + """Name of the tokenizer.""" + return "sd_vae_tokenizer" + + def count_param(self) -> int: + """Count the number of parameters in the model.""" + return sum(parameter.numel() for parameter in self.model.parameters()) diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py index 5e9be83..9ab1eb9 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math.py @@ -311,6 +311,44 @@ def align_uniae_num_video_frames( return num_video_frames +def ceil_uniae_num_video_frames( + num_video_frames: int, + uniae_chunk_frames: int | Mapping[str, int], + *, + pad_frames: int, + temporal_compression_factor: int, + resolution: str | None = None, + spatial_shape: tuple[int, int] | None = None, + target_resolution_key: str | None = None, + missing_resolution_message: str = ( + "spatial_shape or target resolution must be provided for resolution-keyed UniAE chunks" + ), +) -> int: + """Round up to the nearest valid UniAE noncausal 
count, preserving valid partial tails.""" + if num_video_frames < 1: + return 0 + + for candidate in range(num_video_frames, num_video_frames + temporal_compression_factor + 1): + aligned_candidate = align_uniae_num_video_frames( + candidate, + uniae_chunk_frames, + pad_frames=pad_frames, + temporal_compression_factor=temporal_compression_factor, + resolution=resolution, + spatial_shape=spatial_shape, + target_resolution_key=target_resolution_key, + missing_resolution_message=missing_resolution_message, + ) + if aligned_candidate == candidate: + return candidate + + raise RuntimeError( + "Failed to find a valid UniAE frame count within one temporal-compression window: " + f"{num_video_frames=}, {uniae_chunk_frames=}, {pad_frames=}, {temporal_compression_factor=}, " + f"{resolution=}, {spatial_shape=}, {target_resolution_key=}." + ) + + def _validate_full_chunk( full_chunk: int, *, diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py new file mode 100644 index 0000000..b233ac1 --- /dev/null +++ b/cosmos_framework/model/vfm/tokenizers/uniae/frame_math_test.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +from cosmos_framework.model.vfm.tokenizers.uniae.frame_math import ( + align_uniae_num_video_frames, + ceil_uniae_num_video_frames, +) + + +def test_ceil_uniae_num_video_frames_preserves_valid_partial_tail() -> None: + assert ( + ceil_uniae_num_video_frames( + 17, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 17 + ) + + +def test_ceil_uniae_num_video_frames_uses_next_valid_partial_tail() -> None: + assert ( + align_uniae_num_video_frames( + 24, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 21 + ) + assert ( + ceil_uniae_num_video_frames( + 24, + {"480": 16}, + pad_frames=1, + temporal_compression_factor=4, + resolution="480", + ) + == 25 + ) diff --git a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py index 57b168e..445f88f 100644 --- a/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/uniae/noncausal_4x16x16.py @@ -511,6 +511,10 @@ def latent_chunk_duration(self): "Use encode_chunk_frames[res_key] // temporal_compression_factor. Will be removed in a future MR." 
) + @property + def pad_frames(self) -> int: + return self.vae._pad_frames + @property def latent_ch(self) -> int: return self.vae.z_dim diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py index 542c90f..f1368a0 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt1_vae_4x8x8.py @@ -814,7 +814,7 @@ def decode(self, latent: torch.Tensor) -> torch.Tensor: # latent: [B,C,T_latent def get_latent_num_frames(self, num_pixel_frames: int) -> int: return 1 + (num_pixel_frames - 1) // 4 - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return (num_latent_frames - 1) * 4 + 1 @property diff --git a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py index 12dc6d5..a796830 100644 --- a/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py +++ b/cosmos_framework/model/vfm/tokenizers/wan2pt2_vae_4x16x16.py @@ -1653,7 +1653,7 @@ def _get_ref_caches( def get_latent_num_frames(self, num_pixel_frames: int) -> int: return 1 + (num_pixel_frames - 1) // 4 - def get_pixel_num_frames(self, num_latent_frames: int) -> int: + def get_pixel_num_frames(self, num_latent_frames: int, **kwargs) -> int: return (num_latent_frames - 1) * 4 + 1 @property diff --git a/cosmos_framework/model/vfm/utils/memory.py b/cosmos_framework/model/vfm/utils/memory.py index b2671fa..115a8d3 100644 --- a/cosmos_framework/model/vfm/utils/memory.py +++ b/cosmos_framework/model/vfm/utils/memory.py @@ -64,7 +64,7 @@ def init(self, hidden_states: dict, device: torch.device) -> None: Called once before any transformer layers are processed. Args: - hidden_states: The packed sequence (``FactoredSequencePack``). + hidden_states: The packed sequence (``SequencePack``). device: Target device for any new tensors. 
""" diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py index 6a4d772..aa92a54 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/qwen3_vl.py @@ -354,15 +354,18 @@ def apply_interleaved_mrope(self, freqs, mrope_section): @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): assert self.inv_freq.dtype == torch.float32, f"inv_freq must be float32, but got {self.inv_freq.dtype}" + assert position_ids.dtype in [torch.long, torch.float32], ( + f"position_ids must be long or float32, but got {position_ids.dtype}" + ) # In contrast to other models, Qwen3VL has different position ids for the grids # So we expand the inv_freq to shape (3, ...) if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) # [3,B,N] inv_freq_expanded = ( - self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1).to(x.device) + self.inv_freq[None, None, :, None].expand(3, position_ids.shape[1], -1, 1).to(x.device) ) # [3,B,head_dim//2,1] - position_ids_expanded = position_ids[:, :, None, :].float() # [3,B,1,N] + position_ids_expanded = position_ids[:, :, None, :] # [3,B,1,N] freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) # [3,B,N,head_dim//2] freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) # [B,N,head_dim//2] diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py index e39ad3f..0b012fb 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl/utils.py @@ -148,7 +148,7 @@ def prepare_padding_mask( # Pad it if necessary if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0: local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length)) - # 
For flex, we should not slice them, only use an offset + # Some callers require an unsliced mask and apply the offset separately. if _slice: # Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`, # but without data-dependent slicing (i.e. torch.compile friendly) diff --git a/cosmos_framework/model/vfm/vlm/qwen3_vl_moe/qwen3_vl_moe.py b/cosmos_framework/model/vfm/vlm/qwen3_vl_moe/qwen3_vl_moe.py index bd02d70..16144cc 100644 --- a/cosmos_framework/model/vfm/vlm/qwen3_vl_moe/qwen3_vl_moe.py +++ b/cosmos_framework/model/vfm/vlm/qwen3_vl_moe/qwen3_vl_moe.py @@ -1143,15 +1143,18 @@ def apply_interleaved_mrope(self, freqs, mrope_section): @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): assert self.inv_freq.dtype == torch.float32, f"inv_freq must be float32, but got {self.inv_freq.dtype}" + assert position_ids.dtype in [torch.long, torch.float32], ( + f"position_ids must be long or float32, but got {position_ids.dtype}" + ) # In contrast to other models, Qwen3VLMoe has different position ids for the grids # So we expand the inv_freq to shape (3, ...) 
if position_ids.ndim == 2: position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) # [3,B,N] - inv_freq_expanded = ( - self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + inv_freq_expanded = self.inv_freq[None, None, :, None].expand( + 3, position_ids.shape[1], -1, 1 ) # [3,B,head_dim//2,1] - position_ids_expanded = position_ids[:, :, None, :].float() # [3,B,1,N] + position_ids_expanded = position_ids[:, :, None, :] # [3,B,1,N] freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) # [3,B,N,head_dim//2] freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) # [B,N,head_dim//2] diff --git a/cosmos_framework/utils/config_helper.py b/cosmos_framework/utils/config_helper.py index 88c8f07..e81dc81 100644 --- a/cosmos_framework/utils/config_helper.py +++ b/cosmos_framework/utils/config_helper.py @@ -6,6 +6,7 @@ import os import pkgutil import sys +import threading from dataclasses import fields as dataclass_fields from dataclasses import is_dataclass from typing import Any, Dict, Optional @@ -20,6 +21,8 @@ from cosmos_framework.utils.config import Config from cosmos_framework.utils import log +_HYDRA_LOCK = threading.RLock() + def is_attrs_or_dataclass(obj) -> bool: """ @@ -77,16 +80,20 @@ def remove_defaults_filter(f, _): f'Hydra config overrides must be separated with a "--" token. but got overrides={overrides}, and overrides[0]={overrides[0]}' ) overrides = overrides[1:] - # Use Hydra to handle overrides - cs = ConfigStore.instance() - cs.store(name="config", node=config_omegaconf) - if not GlobalHydra().is_initialized(): - with initialize(version_base=None): + # Use Hydra to handle overrides. Hydra's GlobalHydra and ConfigStore are + # process-global; this function stores the caller's root config under the + # fixed name "config" and then composes that name, so concurrent calls can + # observe each other's root config or race GlobalHydra initialization. 
+    with _HYDRA_LOCK:
+        cs = ConfigStore.instance()
+        cs.store(name="config", node=config_omegaconf)
+        if not GlobalHydra().is_initialized():
+            with initialize(version_base=None):
+                config_omegaconf = compose(config_name="config", overrides=overrides)
+                OmegaConf.resolve(config_omegaconf)
+        else:
             config_omegaconf = compose(config_name="config", overrides=overrides)
             OmegaConf.resolve(config_omegaconf)
-    else:
-        config_omegaconf = compose(config_name="config", overrides=overrides)
-        OmegaConf.resolve(config_omegaconf)
 
     def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
         """
diff --git a/cosmos_framework/utils/easy_io/transient_retry.py b/cosmos_framework/utils/easy_io/transient_retry.py
new file mode 100644
index 0000000..9eca405
--- /dev/null
+++ b/cosmos_framework/utils/easy_io/transient_retry.py
@@ -0,0 +1,153 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""Shared retry helpers for transient object-store / transport errors.
+
+This module is the single source of truth for what counts as a *transient*
+(retryable) error when talking to an object store, and a small synchronous
+``retry_on_transient_error`` wrapper around individual operations (e.g. an
+``easy_io`` ``get``/``exists`` call).
+
+It lives under ``easy_io`` -- the lowest-level storage layer -- so any caller
+(datasets, ``ObjectStore``, etc.) can depend on it without creating a reverse
+dependency on the higher-level dataset code.
+"""
+
+from __future__ import annotations
+
+import time
+from http.client import IncompleteRead
+from typing import Callable, TypeVar
+
+from botocore.exceptions import (
+    ConnectionClosedError,
+    EndpointConnectionError,
+    ResponseStreamingError,
+)
+from botocore.exceptions import (
+    ReadTimeoutError as BotocoreReadTimeoutError,
+)
+from multistorageclient.types import RetryableError
+from urllib3.exceptions import ProtocolError as URLLib3ProtocolError
+from urllib3.exceptions import ReadTimeoutError as URLLib3ReadTimeoutError
+from urllib3.exceptions import SSLError as URLLib3SSLError
+
+from cosmos_framework.utils import log
+
+__all__ = [
+    "RETRYABLE_EXCEPTIONS",
+    "is_retryable_exception",
+    "retry_on_transient_error",
+]
+
+T = TypeVar("T")
+
+# Exceptions that indicate a *transient* transport failure and are worth
+# retrying. These are connection/timeout/SSL/protocol errors -- never logical
+# errors (e.g. a missing key, bad credentials, or a programming bug).
+RETRYABLE_EXCEPTIONS = (
+    # built-in
+    IOError,
+    # http
+    IncompleteRead,
+    # urllib3
+    URLLib3ReadTimeoutError,
+    URLLib3ProtocolError,
+    URLLib3SSLError,
+    # AWS SDK for Python (boto)
+    BotocoreReadTimeoutError,
+    ConnectionClosedError,
+    EndpointConnectionError,
+    ResponseStreamingError,
+    # NVIDIA Multi-Storage Client (MSC)
+    RetryableError,
+)
+
+# ``IOError`` is an alias of ``OSError``, so the tuple above also matches
+# filesystem failures that no amount of retrying will fix. These subclasses
+# are excluded so that, e.g., a missing local file fails fast instead of
+# sleeping through the whole backoff schedule.
+_NON_TRANSIENT_OS_ERRORS = (
+    FileNotFoundError,
+    FileExistsError,
+    IsADirectoryError,
+    NotADirectoryError,
+    PermissionError,
+)
+
+# Default retry policy. Tuned for transient SSL / connection resets seen on
+# long-running training jobs: a handful of attempts with exponential backoff
+# is enough to ride out a brief blip without masking a real outage for long.
+DEFAULT_MAX_RETRIES = 5
+DEFAULT_BASE_DELAY_S = 0.5
+
+
+def is_retryable_exception(exc: BaseException) -> bool:
+    """Return True if ``exc`` -- or any exception in its cause/context chain -- is retryable.
+
+    Object-store clients frequently re-wrap a transient transport error (e.g. an
+    ``SSLError``) inside an opaque, non-retryable exception. Walking the
+    ``__cause__``/``__context__`` chain lets us still recognise it.
+
+    A link in the chain counts as retryable when it is an instance of
+    ``RETRYABLE_EXCEPTIONS`` and is not one of the non-transient ``OSError``
+    subclasses listed in ``_NON_TRANSIENT_OS_ERRORS``.
+    """
+    seen: set[int] = set()
+    current: BaseException | None = exc
+    # ``seen`` stops the walk if the exception chain ever loops back on itself.
+    while current is not None and id(current) not in seen:
+        if isinstance(current, RETRYABLE_EXCEPTIONS) and not isinstance(current, _NON_TRANSIENT_OS_ERRORS):
+            return True
+        seen.add(id(current))
+        current = current.__cause__ or current.__context__
+    return False
+
+
+def retry_on_transient_error(
+    func: Callable[[], T],
+    operation: str,
+    max_retries: int = DEFAULT_MAX_RETRIES,
+    base_delay: float = DEFAULT_BASE_DELAY_S,
+) -> T:
+    """Call ``func`` and retry it on transient transport errors with exponential backoff.
+
+    Args:
+        func: Zero-argument callable performing a single object-store operation.
+        operation: Human-readable label for logging (e.g. ``"load_object(foo.pt)"``).
+        max_retries: Total number of attempts (must be >= 1).
+        base_delay: Base delay in seconds; attempt ``i`` (0-indexed) sleeps
+            ``base_delay * 2**i`` before the next try.
+
+    Returns:
+        Whatever ``func`` returns on the first successful attempt.
+
+    Raises:
+        The last exception if all attempts fail, or immediately if a raised
+        exception is not transient (see :func:`is_retryable_exception`).
+    """
+    assert max_retries >= 1, f"max_retries must be >= 1, got {max_retries}"
+
+    for attempt in range(max_retries):
+        try:
+            return func()
+        except Exception as e:
+            # Re-raise anything that is not a transient transport error, unchanged.
+            if not is_retryable_exception(e):
+                raise
+            # Out of retries: re-raise the last (transient) exception.
+            if attempt == max_retries - 1:
+                log.warning(
+                    f"[{operation}] {type(e).__name__}: {e} -- giving up after {max_retries} attempts",
+                    rank0_only=False,
+                )
+                raise
+            delay = base_delay * 2**attempt
+            log.warning(
+                f"[{operation}] {type(e).__name__}: {e} -- retry {attempt + 1}/{max_retries} in {delay:.1f}s",
+                rank0_only=False,
+            )
+            time.sleep(delay)
+
+    # Unreachable: the loop either returns or raises on the final attempt.
+    raise AssertionError("retry_on_transient_error exited its loop without returning or raising")
diff --git a/cosmos_framework/utils/easy_io/transient_retry_test.py b/cosmos_framework/utils/easy_io/transient_retry_test.py
new file mode 100644
index 0000000..7fa1364
--- /dev/null
+++ b/cosmos_framework/utils/easy_io/transient_retry_test.py
@@ -0,0 +1,102 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+from unittest import mock
+
+import pytest
+from urllib3.exceptions import SSLError as URLLib3SSLError
+
+from cosmos_framework.utils.easy_io.transient_retry import (
+    is_retryable_exception,
+    retry_on_transient_error,
+)
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_is_retryable_exception_direct():
+    # A transient transport error is retryable.
+    assert is_retryable_exception(URLLib3SSLError("ssl handshake failed"))
+    assert is_retryable_exception(IOError("connection reset"))
+    # A logical/programming error is not.
+    assert not is_retryable_exception(ValueError("bad value"))
+    assert not is_retryable_exception(KeyError("missing"))
+    # ``IOError`` is ``OSError``: non-transient filesystem errors are excluded.
+    assert not is_retryable_exception(FileNotFoundError("missing file"))
+    assert not is_retryable_exception(PermissionError("denied"))
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_is_retryable_exception_walks_cause_chain():
+    # An opaque wrapper that re-raises `from` a transient error is still retryable.
+    try:
+        try:
+            raise URLLib3SSLError("ssl handshake failed")
+        except URLLib3SSLError as inner:
+            raise RuntimeError("opaque client error") from inner
+    except RuntimeError as wrapped:
+        assert is_retryable_exception(wrapped)
+
+    # ...and via the implicit __context__ chain (raise during handling).
+    try:
+        try:
+            raise URLLib3SSLError("ssl handshake failed")
+        except URLLib3SSLError:
+            raise RuntimeError("opaque client error")
+    except RuntimeError as wrapped:
+        assert is_retryable_exception(wrapped)
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_retry_returns_on_success_without_retrying():
+    func = mock.Mock(return_value="ok")
+    with mock.patch("cosmos_framework.utils.easy_io.transient_retry.time.sleep") as sleep:
+        result = retry_on_transient_error(func, operation="op")
+    assert result == "ok"
+    assert func.call_count == 1
+    sleep.assert_not_called()
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_retry_then_success():
+    # Fail twice with a transient error, then succeed.
+    func = mock.Mock(
+        side_effect=[
+            URLLib3SSLError("ssl"),
+            URLLib3SSLError("ssl"),
+            "ok",
+        ]
+    )
+    with mock.patch("cosmos_framework.utils.easy_io.transient_retry.time.sleep") as sleep:
+        result = retry_on_transient_error(func, operation="op", max_retries=5, base_delay=0.5)
+    assert result == "ok"
+    assert func.call_count == 3
+    # Exponential backoff between the two failures: 0.5 * 2**0, then 0.5 * 2**1.
+    assert [c.args[0] for c in sleep.call_args_list] == [0.5, 1.0]
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_non_retryable_reraised_immediately():
+    func = mock.Mock(side_effect=ValueError("logical error"))
+    with mock.patch("cosmos_framework.utils.easy_io.transient_retry.time.sleep") as sleep:
+        with pytest.raises(ValueError, match="logical error"):
+            retry_on_transient_error(func, operation="op", max_retries=5)
+    # No retry for a non-transient error.
+    assert func.call_count == 1
+    sleep.assert_not_called()
+
+
+@pytest.mark.L0
+@pytest.mark.CPU
+def test_retry_exhausted_reraises_last_exception():
+    func = mock.Mock(side_effect=URLLib3SSLError("persistent ssl failure"))
+    with mock.patch("cosmos_framework.utils.easy_io.transient_retry.time.sleep") as sleep:
+        with pytest.raises(URLLib3SSLError, match="persistent ssl failure"):
+            retry_on_transient_error(func, operation="op", max_retries=3, base_delay=0.5)
+    assert func.call_count == 3
+    # Slept before attempts 2 and 3, but not after the final failure.
+    assert [c.args[0] for c in sleep.call_args_list] == [0.5, 1.0]
diff --git a/cosmos_framework/utils/functional/lr_scheduler.py b/cosmos_framework/utils/functional/lr_scheduler.py
index ceaa93d..f6470f5 100644
--- a/cosmos_framework/utils/functional/lr_scheduler.py
+++ b/cosmos_framework/utils/functional/lr_scheduler.py
@@ -164,3 +164,112 @@ def schedule(self, n, **kwargs):
             )
             self.last_f = f
             return f
+
+
+class WSDScheduler:
+    """Warmup-Stable-Decay (WSD) learning rate scheduler for LLM pretraining.
+
+    Three phases:
+      1. **Warmup** (steps 0 .. warm_up_steps-1):
+         Linear ramp from ``f_start`` to ``f_max``.
+      2. **Stable** (steps warm_up_steps .. total_steps - decay_steps - 1):
+         Constant at ``f_max``.
+      3. **Decay** (last ``decay_steps`` steps):
+         Anneal from ``f_max`` to ``f_min`` using ``decay_type``.
+
+    After ``total_steps`` the multiplier holds at ``f_min`` indefinitely,
+    so training can safely overshoot ``max_iter`` without crashing.
+
+    Reference: MiniCPM / Warmup-Stable-Decay schedule
+    (https://arxiv.org/abs/2404.06395)
+
+    Parameters:
+        warm_up_steps: Number of linear warmup steps.
+        total_steps: Total training steps (warmup + stable + decay).
+        decay_steps: Number of decay steps at the end.
+        decay_type: Decay curve shape — ``"cosine"`` or ``"linear"``.
+        f_start: LR multiplier at step 0.
+        f_max: LR multiplier during the stable phase.
+        f_min: LR multiplier after decay completes.
+        verbosity_interval: Log every N steps (0 = silent).
+
+    Raises:
+        ValueError: If ``decay_type`` is neither ``"cosine"`` nor ``"linear"``.
+
+    Examples:
+        >>> scheduler = WSDScheduler(
+        ...     warm_up_steps=2000, total_steps=50000, decay_steps=5000,
+        ...     decay_type="cosine", f_start=0.01, f_max=1.0, f_min=0.1)
+        >>> for step in range(55000):
+        ...     lr_multiplier = scheduler(step)
+    """
+
+    def __init__(
+        self,
+        warm_up_steps: int,
+        total_steps: int,
+        decay_steps: int,
+        f_start: float,
+        f_max: float,
+        f_min: float,
+        decay_type: str = "cosine",
+        verbosity_interval: int = 0,
+    ):
+        if decay_type not in ("cosine", "linear"):
+            raise ValueError(f"decay_type must be 'cosine' or 'linear' now, got '{decay_type}'")
+        if warm_up_steps + decay_steps > total_steps:
+            # The schedule still evaluates, but warmup runs into the decay window,
+            # so the multiplier changes abruptly when warmup ends. Warn, don't fail.
+            log.warning(
+                f"WSDScheduler: warm_up_steps ({warm_up_steps}) + decay_steps ({decay_steps}) "
+                f"exceeds total_steps ({total_steps}); the warmup and decay phases overlap."
+            )
+        self.warm_up_steps = warm_up_steps
+        self.total_steps = total_steps
+        self.decay_steps = decay_steps
+        self.decay_type = decay_type
+        self.stable_end = total_steps - decay_steps
+        self.f_start = f_start
+        self.f_max = f_max
+        self.f_min = f_min
+        self.verbosity_interval = verbosity_interval
+        self.last_f = 0.0
+        self._model = None
+
+    @property
+    def model(self):
+        return self._model
+
+    @model.setter
+    def model(self, model):
+        self._model = model
+
+    def schedule(self, n, **kwargs):
+        """Return the LR multiplier for step ``n``; extra keyword arguments are ignored."""
+        if n < self.warm_up_steps:
+            # Warmup: linear ramp
+            f = self.f_start + (self.f_max - self.f_start) * n / self.warm_up_steps
+        elif n < self.stable_end:
+            # Stable: constant
+            f = self.f_max
+        elif n < self.total_steps:
+            # Decay
+            t = (n - self.stable_end) / self.decay_steps
+            if self.decay_type == "cosine":
+                # np.cos yields a NumPy scalar; cast so this branch hands back a
+                # builtin float instead of leaking np.float64 to the caller.
+                f = float(self.f_min + 0.5 * (self.f_max - self.f_min) * (1 + np.cos(t * np.pi)))
+            else:  # linear
+                f = self.f_max + (self.f_min - self.f_max) * t
+        else:
+            # Past total_steps: hold at f_min
+            f = self.f_min
+
+        self.last_f = f
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            log.info(f"current step: {n}, lr-multiplier: {f:.6f}")
+        return f
+
+    def __call__(self, n, **kwargs):
+        """Alias for :meth:`schedule` so the instance can be used directly as a callable."""
+        return self.schedule(n, **kwargs)
diff --git a/cosmos_framework/utils/object_store.py b/cosmos_framework/utils/object_store.py
index 58cd2b1..5054571 100644
--- 
a/cosmos_framework/utils/object_store.py +++ b/cosmos_framework/utils/object_store.py @@ -18,6 +18,7 @@ from cosmos_framework.utils import distributed, log from cosmos_framework.utils.easy_io import easy_io +from cosmos_framework.utils.easy_io.transient_retry import retry_on_transient_error Image.MAX_IMAGE_PIXELS = None @@ -87,7 +88,13 @@ def load_object( """ assert type is not None or load_func is not None, "Either type or load_func should be specified." - buffer = io.BytesIO(self.easy_io_backend.get(filepath=self._translate_key(key=key))) + path = self._translate_key(key=key) + buffer = io.BytesIO( + retry_on_transient_error( + lambda: self.easy_io_backend.get(filepath=path), + operation=f"load_object({key})", + ) + ) buffer.seek(0) # Read from buffer for common data types. @@ -177,7 +184,11 @@ def object_exists(self, key: str) -> bool: Returns: bool: True if the object exists, False if not. """ - return self.easy_io_backend.exists(filepath=self._translate_key(key=key)) + path = self._translate_key(key=key) + return retry_on_transient_error( + lambda: self.easy_io_backend.exists(filepath=path), + operation=f"object_exists({key})", + ) def sync_s3_dir_to_local( diff --git a/cosmos_framework/utils/vfm/model_loader.py b/cosmos_framework/utils/vfm/model_loader.py index c180ca3..6e6a0dd 100644 --- a/cosmos_framework/utils/vfm/model_loader.py +++ b/cosmos_framework/utils/vfm/model_loader.py @@ -18,7 +18,21 @@ try: from filelock import SoftReadWriteLock except ImportError: # Older filelock versions in some inference containers. 
- from filelock import ReadWriteLock as SoftReadWriteLock + try: + from filelock import ReadWriteLock as SoftReadWriteLock + except ImportError: + from filelock import FileLock + + class SoftReadWriteLock: + """Compatibility adapter for filelock versions without read/write locks.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._lock = FileLock(*args, **kwargs) + + def write_lock(self) -> FileLock: + return self._lock + + from torch.distributed.checkpoint.filesystem import FileSystemReader, FileSystemWriter from cosmos_framework.checkpoint.s3_filesystem import S3StorageReader @@ -171,6 +185,32 @@ def _checkpoint_cache_group_lock( yield action +def _reload_pretrained_reasoner_after_checkpoint_load(model: torch.nn.Module) -> None: + """Re-seed the reasoner pathway after a DCP load, mirroring the LoadPretrained + callback that runs during training (inference does not run training callbacks). + + The decision is delegated entirely to the model's own gate in + ``load_pretrained_model_if_needed``: this is a no-op unless the model was built + with ``exclude_reasoner_weights_from_checkpoint=True`` (and pretrained weights + enabled), i.e. the case where the DCP checkpoint deliberately omits the reasoner + tower so it must be re-seeded from the pretrained source. For a normal checkpoint + that already contains the reasoner, the model's gate evaluates to False and + nothing is reloaded. + + ``has_resumable_checkpoint=True`` / ``has_load_path=False`` is load-bearing: it + re-seeds the reasoner from the pretrained source while skipping the + understanding->generation copy (the generation pathway was already populated by + the DCP load). Passing ``has_load_path=True`` would instead force a reasoner + reload even for non-excluded checkpoints, clobbering any fine-tuned reasoner + weights restored from the DCP. 
+ """ + load_pretrained_model_if_needed = getattr(model, "load_pretrained_model_if_needed") + load_pretrained_model_if_needed( + has_resumable_checkpoint=True, + has_load_path=False, + ) + + def _load_model( model: torch.nn.Module, checkpoint_path: str, @@ -194,6 +234,9 @@ def _load_model( start_time = time.time() state_dict = ModelWrapper(model).state_dict() + if any(key.startswith("net_teacher.") for key in state_dict): + log.info("Dropping net_teacher.* keys from inference load target; distillation checkpoints do not save them.") + state_dict = {key: value for key, value in state_dict.items() if not key.startswith("net_teacher.")} if checkpoint_path.startswith("s3://"): storage_reader = S3StorageReader( @@ -351,6 +394,16 @@ def load_model_from_checkpoint( # Disable EMA for inference. config.model.config.ema.enabled = False + if hasattr(config.model.config, "load_teacher_weights"): + log.info("Setting load_teacher_weights=False for inference to skip teacher checkpoint download.") + config.model.config.load_teacher_weights = False + + if ( + config.model.config.exclude_reasoner_weights_from_checkpoint + and not config.model.config.vlm_config.pretrained_weights.enabled + ): + log.info("Enabling pretrained reasoner weights because this checkpoint excludes the reasoner tower from DCP.") + config.model.config.vlm_config.pretrained_weights.enabled = True config.validate() config.freeze() # type: ignore @@ -426,6 +479,7 @@ def load_model(checkpoint_load_path: str) -> None: if checkpoint_cache_path is None: load_model(checkpoint_path) + _reload_pretrained_reasoner_after_checkpoint_load(model) return model, config cache_lock_path = f"{checkpoint_cache_path}.lock" @@ -443,4 +497,6 @@ def load_model(checkpoint_load_path: str) -> None: if cache_action == _CheckpointCacheAction.LOAD_CACHE: load_model(checkpoint_cache_path) + _reload_pretrained_reasoner_after_checkpoint_load(model) + return model, config diff --git a/cosmos_framework/utils/vfm/optimizer.py 
b/cosmos_framework/utils/vfm/optimizer.py index b9947a6..1930cbc 100644 --- a/cosmos_framework/utils/vfm/optimizer.py +++ b/cosmos_framework/utils/vfm/optimizer.py @@ -12,7 +12,7 @@ from torch.distributed.checkpoint.stateful import Stateful from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from cosmos_framework.utils.functional.lr_scheduler import LambdaLinearScheduler, LambdaWarmUpCosineScheduler +from cosmos_framework.utils.functional.lr_scheduler import LambdaLinearScheduler, LambdaWarmUpCosineScheduler, WSDScheduler from cosmos_framework.utils import log @@ -404,23 +404,26 @@ def build_optimizer( def _lr_scheduler_cls( lr_scheduler_type: str, **lr_scheduler_kwargs: Any, -) -> LambdaLinearScheduler | LambdaWarmUpCosineScheduler: +) -> LambdaLinearScheduler | LambdaWarmUpCosineScheduler | WSDScheduler: """Instantiate a lambda-style scheduler whose ``.schedule(step)`` returns an LR multiplier. Both returned classes expose a ``schedule(step) -> float`` callable that :class:`LRSchedulersContainer` wraps with ``torch.optim.lr_scheduler.LambdaLR`` to drive each optimizer's param-group LRs. ``lr_scheduler_type`` matching is - case-insensitive; valid values are ``"lambdalinear"`` (linear decay) and - ``"lambdacosine"`` (warmup + cosine decay). Any other value raises - ``NotImplementedError``. All remaining ``**lr_scheduler_kwargs`` are - forwarded verbatim to the underlying scheduler constructor (e.g. - ``warm_up_steps``, ``cycle_lengths``, ``f_start``, ``f_max``, ``f_min``, + case-insensitive; valid values are ``"lambdalinear"`` (linear decay), + ``"lambdacosine"`` (warmup + cosine decay), and ``"wsd"`` + (warmup-stable-decay). Any other value raises ``NotImplementedError``. + All remaining ``**lr_scheduler_kwargs`` are forwarded verbatim to the + underlying scheduler constructor (e.g. ``warm_up_steps``, ``cycle_lengths``, + ``total_steps``, ``decay_steps``, ``f_start``, ``f_max``, ``f_min``, ``verbosity_interval``). 
""" if lr_scheduler_type.lower() == "lambdalinear": lr_scheduler = LambdaLinearScheduler(**lr_scheduler_kwargs) elif lr_scheduler_type.lower() == "lambdacosine": lr_scheduler = LambdaWarmUpCosineScheduler(**lr_scheduler_kwargs) + elif lr_scheduler_type.lower() == "wsd": + lr_scheduler = WSDScheduler(**lr_scheduler_kwargs) else: raise NotImplementedError(f"LR Scheduler {lr_scheduler_type} not found.") return lr_scheduler @@ -574,7 +577,7 @@ def build_lr_scheduler( :class:`torch.optim.lr_scheduler.LambdaLR` will be built per element of this container. lr_scheduler_type: Scheduler kind accepted by :func:`_lr_scheduler_cls` - — ``"lambdalinear"`` or ``"lambdacosine"`` (case-insensitive). + — ``"lambdalinear"``, ``"lambdacosine"``, or ``"wsd"`` (case-insensitive). **lr_scheduler_kwargs: Forwarded verbatim to the underlying lambda scheduler constructor (e.g. ``warm_up_steps``, ``cycle_lengths``, ``f_start``, ``f_max``, ``f_min``, ``verbosity_interval``). diff --git a/tests/launch_regression_test.py b/tests/launch_regression_test.py index 0106a2c..04322e5 100644 --- a/tests/launch_regression_test.py +++ b/tests/launch_regression_test.py @@ -586,15 +586,19 @@ def test_launch_regression_8gpu(spec_key: str, tmp_path: Path, h100_inputs: dict 39.70305, 48.52226, 52.18334, 22.77521, 25.06970, ], }, - # Captured 2026-06-09 on a 4 × NVIDIA GB200 node with seed 42 against the + # Recaptured 2026-06-25 on a 4 × NVIDIA GB200 node with seed 42 against the # current TOML-config pipeline (inputs prepared in-test by ``h100_inputs``, - # which now also serves gb200). Runs under ``--deterministic`` so loss - # reproduces bit-exact across all 10 iters; loss matches the h100 nano - # series within ~1e-3. grad_norm is non-det because ``compile.enabled=true`` - # makes the all-rank reduction not bit-exact, so None (same as h100). + # which now also serves gb200). 
The numerical shift from the 2026-06-09 + # capture reflects the rectified-flow sigma-sampling refactor + # (``t = 1 - t_raw`` flip moved into the sampler via per-sample ``shifts``) + # and is expected. Runs under ``--deterministic`` so loss reproduces bit-exact + # across all 10 iters. grad_norm is deterministic here (compile.enabled=false + # in nano_model_config under the new release branch), so values are pinned; + # flip to None if a future change re-enables compile and reintroduces + # non-determinism in the all-rank reduction. "vision_sft_nano": { - "loss": [0.2269, 0.2181, 0.2026, 0.2309, 0.2178, 0.273, 0.2871, 0.2164, 0.2059, 0.264], - "grad_norm": None, + "loss": [0.2243, 0.2133, 0.2437, 0.2255, 0.2616, 0.2552, 0.3313, 0.2247, 0.2036, 0.2621], + "grad_norm": [0.42188, 0.30469, 0.30078, 0.26953, 0.30273, 0.41406, 0.42773, 0.38477, 0.27344, 0.27344], }, }, # Recaptured 2026-06-03 on a 4 × NVIDIA H100 80GB HBM3 node with seed 42 and @@ -616,7 +620,7 @@ def test_launch_regression_8gpu(spec_key: str, tmp_path: Path, h100_inputs: dict # ``compile.enabled=true`` makes the all-rank reduction not bit-exact # on H100. "vision_sft_nano": { - "loss": [0.2272, 0.2181, 0.2028, 0.2306, 0.218, 0.2734, 0.2865, 0.2162, 0.2055, 0.2643], + "loss": [0.2242, 0.2141, 0.2429, 0.2259, 0.2608, 0.2555, 0.332, 0.2256, 0.2041, 0.2621], "grad_norm": None, }, "vision_sft_super": {