Logos-Lab/factorized-diffusion-policies

 
 

Factorized Diffusion Policies (FDP)

Factorizing Diffusion Policies for Observation Modality Prioritization

🎉 Accepted to ICRA 2026 (IEEE International Conference on Robotics and Automation) 🎉

Omkar Patil¹, Prabin Kumar Rath¹, Kartikay Pangaonkar¹, Eric Rosen, Nakul Gopalan¹

¹Arizona State University

[Paper] [Project Website]


Abstract

Diffusion models have been extensively leveraged for learning robot skills from demonstrations. These policies are conditioned on several observational modalities such as proprioception, vision and tactile. However, observational modalities have varying levels of influence for different tasks that diffusion policies fail to capture. We propose Factorized Diffusion Policies (FDP), a novel policy formulation that enables observational modalities to have differing influence on the action diffusion process by design. This results in learning policies where certain observation modalities can be prioritized over others such as vision > tactile or proprioception > vision.

FDP achieves modality prioritization by factorizing the observational conditioning for the diffusion process, resulting in more performant and robust policies:

  • 15% absolute improvement in success rate on several simulated benchmarks in low-data regimes
  • 40% higher absolute success rate across visuomotor tasks under distribution shifts (visual distractors, camera occlusions)
  • 5× success rate over standard diffusion policies under camera occlusion settings

Method Overview

FDP learns a base policy (πbase) using prioritized input modalities and a residual policy (πres) that captures the effect of the remaining modalities. The base and residual model outputs are composed to obtain samples from the full conditional action distribution. Key design choices:

  1. Two-phase training: First train πbase with prioritized modalities (e.g., proprioception), then train πres with all modalities while keeping πbase frozen
  2. Block-wise composition: Instead of composing final score outputs, πres learns block-wise residuals over intermediate outputs of the frozen πbase
  3. Zero-initialized layers: Applied on residual block outputs to avoid harmful updates at the start of training (inspired by ControlNet)
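
The effect of design choices 2 and 3 can be illustrated with a toy example. This is a minimal sketch under stated assumptions, not the repository's implementation: the real models are DiT blocks, whereas here each "block" is a single dense layer on plain Python lists, and the names `ComposedBlock`, `base_w`, `res_w`, and `zero_w` are hypothetical. It only shows that a zero-initialized layer on each residual block output makes the composed network reproduce the frozen base exactly at the start of residual training.

```python
import random


def linear(weights, x):
    """Apply a dense layer without bias: one output per weight row."""
    return [sum(w * v for w, v in zip(row, x)) for row in weights]


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def zero_matrix(rows, cols):
    return [[0.0] * cols for _ in range(rows)]


class ComposedBlock:
    """Toy stand-in: a frozen base block plus a zero-gated residual block."""

    def __init__(self, dim, cond_dim, rng):
        self.base_w = random_matrix(dim, dim, rng)            # frozen after phase 1
        self.res_w = random_matrix(dim, dim + cond_dim, rng)  # trained in phase 2
        self.zero_w = zero_matrix(dim, dim)                   # zero-initialized layer

    def base_forward(self, h):
        return linear(self.base_w, h)

    def forward(self, h, extra_cond):
        base_out = self.base_forward(h)
        # The residual sees the hidden state plus the extra modality features
        # (an assumption made for this sketch).
        res_out = linear(self.res_w, h + extra_cond)
        gated = linear(self.zero_w, res_out)
        # Block-wise composition: add the gated residual to the base output.
        return [b + g for b, g in zip(base_out, gated)]


rng = random.Random(0)
blocks = [ComposedBlock(dim=4, cond_dim=3, rng=rng) for _ in range(2)]
h_base = h_full = [0.5, -1.0, 0.25, 2.0]
vision_features = [1.0, 0.0, -1.0]
for block in blocks:
    h_base = block.base_forward(h_base)
    h_full = block.forward(h_full, vision_features)

# With all zero layers still at zero, the residual path contributes nothing.
assert h_full == h_base
```

Once training moves `zero_w` away from zero, the residual path starts to shift the intermediate outputs, which is the point at which the additional modalities begin to influence the action prediction.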

Installation

Prerequisites

  • Python 3.10
  • CUDA-capable GPU (tested on NVIDIA A5000 / A40)

Environment Setup

# Create conda environment
conda env create -f environment.yml
conda activate comp_robotics

# Set environment variables for data and checkpoint directories
export DATA_DIR="/path/to/your/datasets"
export CKPT_DIR="/path/to/your/checkpoints"

Note: DATA_DIR should point to where your dataset files are stored, and CKPT_DIR is where training checkpoints will be saved. If not set, they default to ~/datasets and ~/checkpoints respectively.
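
The fallback behaviour described in the note can be pictured as follows. This is a minimal sketch rather than the repository's code: `resolve_dirs` is a hypothetical helper, and it assumes nothing beyond the two variable names and defaults stated above.

```python
import os
from pathlib import Path


def resolve_dirs(env=None):
    """Return (data_dir, ckpt_dir), defaulting to ~/datasets and ~/checkpoints.

    Hypothetical helper for illustration only.
    """
    env = os.environ if env is None else env
    home = Path.home()
    data_dir = Path(env.get("DATA_DIR", home / "datasets")).expanduser()
    ckpt_dir = Path(env.get("CKPT_DIR", home / "checkpoints")).expanduser()
    return data_dir, ckpt_dir


if __name__ == "__main__":
    data_dir, ckpt_dir = resolve_dirs()
    print(f"datasets:    {data_dir}")
    print(f"checkpoints: {ckpt_dir}")
```

Running it before launching a job is a quick way to confirm which directories a training run would read from and write to.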

Benchmark-specific Dependencies

Depending on which benchmarks you want to use, additional setup may be required:


Repository Structure

factorized-diffusion-policies/
├── train_diffusion.py            # Train baseline diffusion policies (DP-DiT, DP-UNet) and single-modality policies
├── train_guided_diffusion.py     # Train FDP residual adapters on additional modalities
├── train_smol_vla.py             # Train SmolVLA (Vision-Language-Action) baseline
├── config/                       # Hydra configuration files
│   ├── train_diffusion.yaml      # Main training configuration
│   ├── rlbench_rollout.yaml      # RLBench rollout config
│   ├── robomimic_rollout.yaml    # Robomimic rollout config
│   ├── d4rl_rollout.yaml         # D4RL/Adroit rollout config
│   ├── m3l_rollout.yaml          # M3L (tactile) rollout config
│   ├── pcd_rlbench_rollout.yaml  # Point-cloud RLBench rollout config
│   ├── rbmc_img_rollout.yaml     # Robomimic image rollout config
│   ├── smolvla_rlbench_rollout.yaml
│   ├── datasets/                 # Dataset-specific configurations
│   │   ├── rl_bench.yaml         # RLBench dataset config
│   │   ├── robo_mimic.yaml       # Robomimic dataset config
│   │   ├── d4rl.yaml             # D4RL/Adroit dataset config
│   │   ├── m3l.yaml              # M3L tactile dataset config
│   │   ├── pcd_rl_bench.yaml     # Point-cloud RLBench config
│   │   └── rbmc_img.yaml         # Robomimic image config
│   └── models/                   # Model architecture configurations
│       ├── dit.yaml              # DiT baseline (DP-DiT)
│       ├── unet.yaml             # UNet baseline (DP-UNet)
│       ├── gdn_dit.yaml          # FDP: prop > vision (conv zero-layer)
│       ├── gdn_dit_linear.yaml   # FDP: prop > vision (linear zero-layer)
│       ├── vis_gdn_dit.yaml      # FDP: vision > prop
│       ├── vis_gdn_dit_linear.yaml
│       ├── pcd_dit.yaml          # Point-cloud DiT
│       ├── pcd_gdn_dit.yaml      # FDP: pcd > vision
│       ├── pcd_unet.yaml         # Point-cloud UNet
│       ├── cross_attn.yaml       # Cross-attention variant
│       ├── cfg_dit.yaml          # Classifier-free guidance DiT
│       └── lowdim_gdn_dit.yaml   # Low-dimensional FDP
├── models/                       # Model implementations
│   ├── diffusion/                # Baseline diffusion policies
│   │   ├── diffusion_policy.py   # Standard Diffusion Policy (DP)
│   │   ├── dit_model.py          # Diffusion Transformer (DiT) architecture
│   │   ├── cfg_diffusion_policy.py  # Classifier-Free Guidance DP
│   │   ├── cross_attn_transformer.py
│   │   └── pcd_diffusion_policy.py  # Point-cloud DP
│   ├── gdn_diffusion/            # FDP (Factorized Diffusion Policy) models
│   │   ├── gdn_diffusion_policy.py    # FDP: prop > vision policy
│   │   ├── gdn_dit.py                # FDP residual DiT (conv zero-layer)
│   │   ├── gdn_dit_linear.py         # FDP residual DiT (linear zero-layer)
│   │   ├── vision_gdn_diffusion_policy.py  # FDP: vision > prop policy
│   │   ├── vision_gdn_dit.py         # FDP vision residual DiT
│   │   ├── vision_gdn_dit_linear.py   # FDP vision residual DiT (linear)
│   │   ├── pcd_gdn_diffusion_policy.py  # FDP: pcd > vision policy
│   │   └── pcd_gdn_dit.py            # FDP point-cloud residual DiT
│   ├── img_encoder/              # Image encoders (ResNet-18 backbones)
│   ├── pcd_encoder/              # Point-cloud encoders (PointNet)
│   ├── unet/                     # UNet architecture
│   ├── vla/                      # Vision-Language-Action models
│   └── tools/                    # Utilities (EMA, normalizer, LR scheduler)
├── datasets/                     # Dataset loading and preprocessing
│   ├── rl_bench.py               # RLBench dataset loader
│   ├── robo_mimic.py             # Robomimic dataset loader
│   ├── d4rl.py                   # D4RL/Adroit dataset loader
│   ├── m3l.py                    # M3L tactile dataset loader
│   ├── pcd_rl_bench.py           # Point-cloud RLBench loader
│   └── lerobot_rl_bench.py       # LeRobot RLBench loader (for SmolVLA)
├── policy_rollout/               # Policy evaluation / rollout scripts
│   ├── rollout.py                # Base rollout class
│   ├── comp_inference.py         # Compositional inference (score composition)
│   ├── rlbench_rollout.py        # RLBench environment rollout
│   ├── robomimic_rollout.py      # Robomimic environment rollout
│   ├── d4rl_rollout.py           # D4RL/Adroit environment rollout
│   ├── m3l_rollout.py            # M3L tactile environment rollout
│   ├── pcd_rlbench_rollout.py    # Point-cloud RLBench rollout
│   ├── rbmc_img_rollout.py       # Robomimic image rollout
│   ├── smolvla_rlbench_rollout.py  # SmolVLA RLBench rollout
│   └── benchmark_compute.py      # Compute benchmarking utilities
├── envs/                         # Environment wrappers & visualization
│   ├── d4rl/                     # D4RL environments (Kitchen, Adroit, AntMaze)
│   ├── m3l/                      # M3L visualization
│   ├── rl_bench/                 # RLBench dataset generation & visualization
│   └── robomimic/                # Robomimic visualization
├── utils/                        # General utilities
│   ├── misc.py                   # Constants, seed, checkpoint search, helpers
│   ├── train_utils.py            # BaseTrainer with checkpointing
│   └── checkpoint_util.py        # Top-K checkpoint management
├── scripts/                      # Analysis and data processing scripts
├── environment.yml               # Conda environment specification
└── LICENSE                       # MIT License

Training

This codebase uses Hydra for configuration management and Weights & Biases for experiment logging. All training runs are configured via YAML files in config/.

Step 1: Train a Baseline Diffusion Policy (πbase)

Train a standard jointly-conditioned diffusion policy or a single-modality base model:

# Train a DiT-based diffusion policy on RLBench (all modalities)
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=rl_bench \
    ++datasets.filepath=open_box \
    ++max_demos=50

# Train a proprioception-only motion model (π_base for FDP prop>vision)
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=rl_bench \
    ++datasets.filepath=open_box \
    ++max_demos=50 \
    ++mm=true

# Train on Robomimic (low-dimensional)
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=robo_mimic \
    ++datasets.filepath=can_low_dim \
    ++max_demos=50

# Train on D4RL / Adroit
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=d4rl \
    ++datasets.filepath=door_human \
    ++max_demos=5

# Train with UNet architecture
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=rl_bench \
    models@_global_=unet \
    ++datasets.filepath=open_box \
    ++max_demos=50

# Train with point-cloud inputs
CUDA_VISIBLE_DEVICES=0 python train_diffusion.py \
    datasets=pcd_rl_bench \
    ++datasets.filepath=pc_open_fridge \
    ++max_demos=5 \
    ++pcd_only=true \
    models@_global_=pcd_dit

Step 2: Train the FDP Residual Adapter (πres)

Once a base model (πbase) is trained, train the residual model conditioned on all modalities:

# Train FDP residual: prop > vision (specify base model checkpoint via mm_bb)
CUDA_VISIBLE_DEVICES=0 python train_guided_diffusion.py \
    datasets=rl_bench \
    ++max_demos=50 \
    models@_global_=gdn_dit \
    ++datasets.filepath=open_box \
    ++mm_bb=<base_model_checkpoint_name>

# Train FDP residual: vision > prop
CUDA_VISIBLE_DEVICES=0 python train_guided_diffusion.py \
    datasets=rl_bench \
    ++max_demos=50 \
    models@_global_=vis_gdn_dit \
    ++datasets.filepath=open_box \
    ++mm_bb=<base_model_checkpoint_name>

# Train FDP residual: pcd > vision
CUDA_VISIBLE_DEVICES=0 python train_guided_diffusion.py \
    datasets=pcd_rl_bench \
    ++max_demos=50 \
    models@_global_=pcd_gdn_dit \
    ++datasets.filepath=pc_open_fridge \
    ++mm_bb=<base_model_checkpoint_name>
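
The two training phases can also be assembled programmatically, for example when sweeping over tasks. The sketch below only prints the command lines; it uses the scripts and overrides shown above, while the helper names `train_command` and `two_phase_commands` are hypothetical and the checkpoint name is left as a placeholder because it depends on what phase 1 saved.

```python
import shlex


def train_command(script, dataset, task, max_demos, model=None, extra=()):
    """Assemble a Hydra-style command line as an argv list."""
    argv = ["python", script, f"datasets={dataset}"]
    if model is not None:
        argv.append(f"models@_global_={model}")
    argv += [f"++datasets.filepath={task}", f"++max_demos={max_demos}", *extra]
    return argv


def two_phase_commands(task, max_demos, base_ckpt):
    """Phase 1: proprioception-only base. Phase 2: residual on top of it."""
    base = train_command(
        "train_diffusion.py", "rl_bench", task, max_demos,
        extra=["++mm=true"],
    )
    residual = train_command(
        "train_guided_diffusion.py", "rl_bench", task, max_demos,
        model="gdn_dit", extra=[f"++mm_bb={base_ckpt}"],
    )
    return base, residual


if __name__ == "__main__":
    # The placeholder stands for the checkpoint name produced by phase 1.
    for argv in two_phase_commands("open_box", 50, "<base_model_checkpoint_name>"):
        print(shlex.join(argv))
```

Building an argv list instead of a single string avoids shell-quoting problems if the commands are later passed to `subprocess.run`.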

Step 3: Train SmolVLA Baseline (Optional)

CUDA_VISIBLE_DEVICES=0 DATASET=sweep_to_dustpan MAX_DEMOS=50 python train_smol_vla.py

Key Training Arguments

| Argument | Description |
| --- | --- |
| datasets=<name> | Dataset config: rl_bench, robo_mimic, d4rl, m3l, pcd_rl_bench, rbmc_img |
| models@_global_=<name> | Model config: dit, unet, gdn_dit, gdn_dit_linear, vis_gdn_dit, pcd_gdn_dit, cfg_dit |
| ++datasets.filepath=<task> | Task name (e.g., open_box, can_low_dim, kitchen_complete_v0, door_human) |
| ++max_demos=<N> | Number of training demonstrations (null for all) |
| ++mm=true | Train a motion-only (proprioception) base model |
| ++vm=true | Train a vision-only base model |
| ++mm_bb=<name> | Load a base model checkpoint for FDP or guided training |
| ++logging.mode=disabled | Disable W&B logging (useful for debugging) |
| ++training.num_epochs=<N> | Override number of training epochs |
| ++device=cuda:0 | Set GPU device |

Policy Rollout / Evaluation

Evaluate trained policies in simulation environments. Rollout scripts are located in policy_rollout/. Each rollout script uses its own Hydra config from config/ (e.g., rlbench_rollout.yaml).

RLBench Rollout

CUDA_VISIBLE_DEVICES=0 python policy_rollout/rlbench_rollout.py \
    ++datasets.filepath=open_box \
    ++ckpt_tag=<checkpoint_tag>

Robomimic Rollout

CUDA_VISIBLE_DEVICES=0 python policy_rollout/robomimic_rollout.py \
    ++datasets.filepath=can_low_dim \
    ++ckpt_tag=<checkpoint_tag>

D4RL / Adroit Rollout

# Unset LD_PRELOAD for video saving (MuJoCo rendering)
unset LD_PRELOAD

CUDA_VISIBLE_DEVICES=0 python policy_rollout/d4rl_rollout.py \
    ++datasets.filepath=kitchen_complete_v0 \
    ++ckpt_tag=<checkpoint_tag>
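
When rollouts are launched from a Python driver rather than an interactive shell, the same `unset LD_PRELOAD` step can be applied to the child process only. This is a sketch under that assumption; `env_without_ld_preload` and `launch_d4rl_rollout` are hypothetical helpers, and the command mirrors the one shown above.

```python
import os
import subprocess


def env_without_ld_preload(env=None):
    """Copy the environment and drop LD_PRELOAD if it is set."""
    env = dict(os.environ if env is None else env)
    env.pop("LD_PRELOAD", None)
    return env


def launch_d4rl_rollout(task, ckpt_tag):
    """Run the D4RL rollout script with LD_PRELOAD removed for the child only."""
    argv = [
        "python", "policy_rollout/d4rl_rollout.py",
        f"++datasets.filepath={task}",
        f"++ckpt_tag={ckpt_tag}",
    ]
    return subprocess.run(argv, env=env_without_ld_preload(), check=True)
```

The parent process keeps its original environment, so other tools that rely on `LD_PRELOAD` are unaffected.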

M3L Tactile Rollout

CUDA_VISIBLE_DEVICES=0 python policy_rollout/m3l_rollout.py \
    ++datasets.filepath=insertion \
    ++ckpt_tag=<checkpoint_tag>

Point Cloud RLBench Rollout

CUDA_VISIBLE_DEVICES=0 python policy_rollout/pcd_rlbench_rollout.py \
    ++datasets.filepath=pc_open_fridge \
    ++ckpt_tag=<checkpoint_tag>

Compositional Inference

For evaluating POCO with score composition at inference, both the FDP checkpoint (ckpt_tag) and the base motion model checkpoint (mm_bb) are required:

CUDA_VISIBLE_DEVICES=0 python policy_rollout/rlbench_rollout.py \
    ++datasets.filepath=open_box \
    ++ckpt_tag=<fdp_checkpoint_tag> \
    ++mm_bb=<base_model_checkpoint_name> \
    ++compose=true \
    ++comp_weights=<weights>
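
A common way to compose diffusion models at inference is to take a weighted sum of their noise (score) predictions at each denoising step. The sketch below shows that generic operation on plain lists. How `comp_weights` is actually parsed and applied in `comp_inference.py` is not documented here, so treat the one-weight-per-model reading and the function name `compose_noise` as assumptions.

```python
def compose_noise(predictions, weights):
    """Weighted sum of per-model noise predictions for one denoising step."""
    if len(predictions) != len(weights):
        raise ValueError("need exactly one weight per model prediction")
    length = len(predictions[0])
    if any(len(p) != length for p in predictions):
        raise ValueError("all predictions must have the same length")
    return [
        sum(w * p[i] for w, p in zip(weights, predictions))
        for i in range(length)
    ]


# Toy noise predictions from a base model and an FDP model (made-up values).
eps_base = [0.25, -0.5, 1.0]
eps_fdp = [0.75, 0.0, -1.0]
composed = compose_noise([eps_base, eps_fdp], weights=[0.5, 0.5])
print(composed)  # [0.5, -0.25, 0.0]
```

Because scores add when densities multiply, a weighted sum of this kind samples from a product of the two models' distributions raised to the given weights, up to the approximation error of the sampler.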

Key Rollout Arguments

| Argument | Description |
| --- | --- |
| ++datasets.filepath=<task> | Task name (e.g., open_box, can_low_dim, door_human) |
| ++ckpt_tag=<tag> | Checkpoint identifier to load |
| ++num_rollouts=<N> | Number of evaluation episodes (default varies per benchmark) |
| ++rollout_steps=<N> | Action chunking horizon: number of actions executed per prediction (default varies per benchmark) |
| ++num_inference_steps=<N> | Number of DDIM denoising steps at inference (default: 8) |
| ++compose=true | Enable compositional inference (FDP base + residual) |
| ++mm_bb=<name> | Base model checkpoint name (required when compose=true) |
| ++comp_weights=<w> | Composition weights for score composition |
| ++mm_cond_mod=<mod> | Base model conditioning modality: proprio, vision, or pcd |
| ++mm=true | Rollout with motion-only (proprioception) model |
| ++pcd_only=true | Use only point-cloud inputs (PCD rollout) |
| ++pcd_noise=<std> | Add Gaussian noise to point clouds for robustness testing (PCD rollout) |
| ++ula=true | Enable Unadjusted Langevin Annealing sampling during compositional inference |
| ++gui=true | Enable GUI visualization (RLBench) |
| ++save_video=true | Save rollout videos |
| ++seed=<N> | Random seed for reproducibility |
| ++device=cuda:0 | Set GPU device |
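
The role of `rollout_steps` is easiest to see in a receding-horizon loop: the policy predicts a chunk of actions, only the first `rollout_steps` of them are executed, and then the policy is queried again. The following is a toy sketch of that pattern, not the repository's rollout class; `run_chunked_rollout`, `predict_chunk`, and `step_env` are hypothetical names, and the "environment" is just an integer counter.

```python
def run_chunked_rollout(predict_chunk, step_env, initial_obs, max_steps, rollout_steps):
    """Execute `rollout_steps` actions per prediction, then replan."""
    if rollout_steps < 1:
        raise ValueError("rollout_steps must be at least 1")
    obs, executed, num_predictions = initial_obs, 0, 0
    while executed < max_steps:
        chunk = predict_chunk(obs)  # the policy proposes a sequence of actions
        num_predictions += 1
        for action in chunk[:rollout_steps]:
            obs = step_env(obs, action)
            executed += 1
            if executed >= max_steps:
                break
    return obs, num_predictions


# Toy setup: the "policy" always proposes 16 unit actions, the "env" sums them.
final_obs, num_predictions = run_chunked_rollout(
    predict_chunk=lambda obs: [1] * 16,
    step_env=lambda obs, action: obs + action,
    initial_obs=0,
    max_steps=20,
    rollout_steps=8,
)
# 20 environment steps at 8 actions per prediction take 3 predictions.
print(final_obs, num_predictions)  # 20 3
```

A smaller `rollout_steps` replans more often and reacts faster to new observations, at the cost of more denoising passes per episode.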

Supported Benchmarks

| Benchmark | Tasks | Observation Modalities | Action Space | Config |
| --- | --- | --- | --- | --- |
| RLBench | 14 visuomotor tasks (open_box, close_box, open_door, etc.) | RGB (multi-camera), proprioception | Joint positions + gripper | datasets=rl_bench |
| RLBench (PCD) | Point-cloud variants | Point clouds, RGB, proprioception | Joint positions + gripper | datasets=pcd_rl_bench |
| Robomimic | Lift, Can, Square, Toolhang | Low-dim state (EE pos/quat, object) | ΔEE pose (axis-angle) | datasets=robo_mimic |
| Adroit | Door, Hammer, Pen, Relocate | Low-dim state (proprioception, environment) | 24-DoF joint actions | datasets=d4rl |
| D4RL Kitchen | Kitchen Complete/Partial/Mixed | Low-dim state | Joint actions | datasets=d4rl |
| M3L | Visuo-tactile peg insertion | RGB camera, tactile sensors | ΔXYZ | datasets=m3l |

Model Configurations

| Config Name | Policy Type | Description |
| --- | --- | --- |
| dit | DiffusionPolicy | DP-DiT baseline (DiT-S, ~33M params) |
| unet | DiffusionPolicy | DP-UNet baseline (1D-UNet) |
| gdn_dit | GdnDiffusionPolicy | FDP: prop > vision (conv zero-layer, ~85M total) |
| gdn_dit_linear | GdnDiffusionPolicy | FDP: prop > vision (linear zero-layer) |
| vis_gdn_dit | VisionGdnDiffusionPolicy | FDP: vision > prop |
| vis_gdn_dit_linear | VisionGdnDiffusionPolicy | FDP: vision > prop (linear) |
| pcd_gdn_dit | PcdGdnDiffusionPolicy | FDP: pcd > vision |
| cfg_dit | CfgDiffusionPolicy | Classifier-Free Guidance baseline |
| cross_attn | Cross-Attention | Cross-attention conditioning variant |

Training Details

  • Architectures: DiT-Small (12 layers, 6 heads, 384 hidden dim), UNet (1D)
  • Image Encoder: ResNet-18 (~12M params per camera), BatchNorm replaced with GroupNorm for EMA compatibility
  • Noise Schedule: 100-step DDPM (cosine β schedule), 8-step DDIM sampling
  • Optimizer: AdamW (lr=1e-4, weight_decay=1e-3 for DiT / 1e-6 for UNet)
  • Training Epochs: 2000 (visual) / 3000 (low-dim) for DiT; 3000 / 5000 for UNet
  • Batch Size: 64 (visual) / 256 (low-dim)
  • EMA: Enabled for all models
  • Hardware: NVIDIA A5000 / A40 GPUs, 6–12 hours training time
  • Inference Latency: ~50ms (DP-DiT), ~100ms (DP-UNet / output composition), ~150ms (FDP block-wise)
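
For readers unfamiliar with the noise schedule entry, the sketch below computes a 100-step cosine β schedule and an 8-step subset of timesteps for DDIM sampling. The list above only says "cosine β schedule", so the specific form used here (the Nichol and Dhariwal formulation with offset `s=0.008` and betas clipped at 0.999) and the evenly strided timestep spacing are assumptions, and the function names are hypothetical.

```python
import math


def cosine_betas(num_steps=100, s=0.008, max_beta=0.999):
    """Cosine noise schedule: betas derived from a squared-cosine alpha-bar curve."""
    def alpha_bar(t):
        return math.cos((t / num_steps + s) / (1 + s) * math.pi / 2) ** 2

    return [
        min(1 - alpha_bar(i + 1) / alpha_bar(i), max_beta)
        for i in range(num_steps)
    ]


def ddim_timesteps(num_train_steps=100, num_inference_steps=8):
    """Evenly strided subset of training timesteps, highest noise level first."""
    stride = num_train_steps // num_inference_steps
    return [i * stride for i in reversed(range(num_inference_steps))]


betas = cosine_betas()
timesteps = ddim_timesteps()
print(len(betas), timesteps)  # 100 [84, 72, 60, 48, 36, 24, 12, 0]
```

Sampling with 8 of the 100 training timesteps is what keeps the inference latencies listed above in the tens to hundreds of milliseconds.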

Citation

If you find this work useful, please cite:

@article{patil2025factorizing,
  title={Factorizing Diffusion Policies for Observation Modality Prioritization},
  author={Patil, Omkar and Rath, Prabin Kumar and Pangaonkar, Kartikay and Rosen, Eric and Gopalan, Nakul},
  journal={arXiv preprint arXiv:2509.16830},
  year={2025}
}

Acknowledgments

This codebase builds upon several excellent open-source projects:


License

This project is licensed under the MIT License — see the LICENSE file for details.
