Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
ab4778b
Add ARO (Adaptively Rotated Optimization) optimizer
JohnLangford Mar 29, 2026
8cd539a
Refactor ARO to reuse megabatch_orthogonalize_async for QR distribution
JohnLangford Mar 29, 2026
09126a2
Add sinkhorn as base_opt for ARO
JohnLangford Mar 29, 2026
7758f51
Default base_opt to sinkhorn (paper's recommendation)
JohnLangford Mar 29, 2026
82e9ce0
Add FSDP support to ARO
JohnLangford Mar 29, 2026
bd02502
Fix cusolver OOM in ARO FSDP path by freeing intermediates before QR
JohnLangford Mar 29, 2026
75b73dc
Fix cusolver OOM: free M_f32 before QR and empty CUDA cache
JohnLangford Mar 29, 2026
ee36c41
Use magma backend for QR to avoid cusolver handle creation failure
JohnLangford Mar 29, 2026
d0dc3db
Move QR to CPU to avoid GPU linalg workspace exhaustion
JohnLangford Mar 29, 2026
9d9c752
Fix NameError for N=1 shape groups under FSDP
JohnLangford Mar 29, 2026
7644740
Move QR back to GPU now that the root cause (N=1 NameError) is fixed
JohnLangford Mar 29, 2026
6d3a92d
Replace Householder QR with Shifted Cholesky QR
JohnLangford Mar 29, 2026
fdacce6
Release cached memory before cusolver calls
Mar 29, 2026
1803abf
Move empty_cache inside _shifted_cholesky_qr, right before cholesky_ex
Mar 30, 2026
d5a31dd
Add synchronize before empty_cache in cusolver path
Mar 30, 2026
5d04dd3
Move Cholesky QR to CPU to avoid cusolver corruption
Mar 30, 2026
8242b92
Revert to CPU QR — solve_triangular also uses cusolver
Mar 30, 2026
0a2b8b9
Disable autocast in aro_ortho_fn and ensure R_my is float32
Mar 31, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .aro import ARO
from .dion import Dion
from .dion import DionMixedPrecisionConfig
from .dion_simple import Dion as DionSimple
Expand Down
374 changes: 374 additions & 0 deletions dion/aro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,374 @@
import math
import torch
from collections import defaultdict
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.distributed.tensor import DTensor
from torch.optim.optimizer import ParamsT
from typing import Callable, Generator, List, Optional, Tuple, Union

from .megabatch_base import DistributedOrthoBase, megabatch_orthogonalize_async
from .opt_utils import AsyncTask, to_local
from .muon import adjust_lr_spectral_norm, adjust_lr_rms_norm


class ARO(DistributedOrthoBase):
    """
    Adaptively Rotated Optimization (ARO) optimizer.

    ARO takes normed steepest-descent steps in a rotated coordinate
    system. The rotation is chosen by a norm-informed policy that ties it
    to the transformation applied by the base optimizer.

    Every parameter of shape [m, n] carries an m×m float32 rotation matrix
    in its optimizer state, i.e. O(m²) extra memory per parameter
    (e.g., 64 MB for m=4096).

    Reference: https://arxiv.org/abs/2602.09006

    Args:
        params: Parameters for the optimizer.
        distributed_mesh: DeviceMesh or ProcessGroup for distributed training.
        lr: Base learning rate.
        mu: Momentum factor for EMA gradient accumulation.
        betas: Tuple of (beta1, beta2) for AdamW and Lion scalar parameter groups.
        weight_decay: Weight decay factor.
        epsilon: Small value for numerical stability.
        base_opt: Base optimizer function applied in the rotated frame.
            "sinkhorn": Alternating row/column L2 normalization (recommended).
            "row_norm": f(X) = sqrt(n) * X / ||x_i|| (row normalization)
            "sign": f(X) = sign(X)
        adjust_lr: How to adjust the learning rate for ARO updates.
            "spectral_norm", "rms_norm", or None.
        flatten: Whether to flatten 3D+ tensors to 2D.
    """

    def __init__(
        self,
        params: ParamsT,
        distributed_mesh: Optional[Union["DeviceMesh", ProcessGroup]] = None,
        lr: float = 0.01,
        mu: float = 0.95,
        betas: Tuple[float, float] = (0.9, 0.95),
        weight_decay: float = 0.01,
        epsilon: float = 1e-8,
        base_opt: str = "sinkhorn",
        adjust_lr: Optional[str] = "rms_norm",
        flatten: bool = False,
    ):
        """Validate the hyperparameters and register them as group defaults.

        Raises:
            ValueError: If lr, mu or a beta is negative, betas does not have
                exactly two entries, or base_opt / adjust_lr is not one of
                the supported choices.
        """
        valid_base_opts = ("row_norm", "sign", "sinkhorn")
        valid_lr_adjustments = ("spectral_norm", "rms_norm", None)

        # Checks run in the same order as the arguments they guard.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if mu < 0.0:
            raise ValueError(f"Invalid momentum factor: {mu}")
        if len(betas) != 2 or any(beta < 0.0 for beta in betas):
            raise ValueError(f"Invalid betas: {betas}")
        if base_opt not in valid_base_opts:
            raise ValueError(
                f"Invalid base_opt: {base_opt}. Must be 'row_norm', 'sign', or 'sinkhorn'."
            )
        if adjust_lr not in valid_lr_adjustments:
            raise ValueError(
                f"Invalid adjust_lr: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None."
            )

        beta1, beta2 = betas[0], betas[1]
        defaults = {
            "lr": lr,
            "mu": mu,
            "beta1": beta1,
            "beta2": beta2,
            "weight_decay": weight_decay,
            "epsilon": epsilon,
            "base_opt": base_opt,
            "flatten": flatten,
            "adjust_lr": adjust_lr,
            "algorithm": "aro",
            "step": 0,
        }
        super().__init__(params, distributed_mesh, "aro", defaults)

    def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict:
        """Return the state dict for ``param``, adding the ARO rotation if missing.

        For the "aro" algorithm the state gains a "rotation" entry: an
        identity matrix of size ``param.shape[-2]`` in float32 on the
        parameter's device. Other algorithms get the base-class state as is.
        """
        state = super()._get_or_initialize_state(param, algo)
        if algo != "aro" or "rotation" in state:
            return state

        rows = param.shape[-2]
        state["rotation"] = torch.eye(rows, device=param.device, dtype=torch.float32)
        return state

    def _create_ortho_tasks(
        self, param_groups: List[dict]
    ) -> Generator["AsyncTask", None, None]:
        """Yield one AsyncTask per (shape, sharding, dtype) bucket of each group.

        Parameters without a gradient are skipped. Each bucket is handed to
        ``aro_update_megabatch_async`` together with its gradients, momentum
        buffers and rotation matrices.
        """
        for group in param_groups:
            assert group["algorithm"] == "aro"
            assert all(
                p.ndim >= 2 for p in group["params"]
            ), "ARO only supports matrix parameters."

            trainable = [p for p in group["params"] if p.grad is not None]
            if not trainable:
                continue

            shared_kwargs = {
                "lr": torch.tensor(group["lr"]),
                "momentum": torch.tensor(group["mu"]),
                "weight_decay": torch.tensor(group["weight_decay"]),
                "epsilon": torch.tensor(group["epsilon"]),
                "base_opt": group["base_opt"],
                "flatten": group["flatten"],
                "adjust_lr": group["adjust_lr"],
                "device_rank": self._device_rank,
                "world_size": self._world_size,
                "process_group": self._process_group,
            }

            # Bucket parameters that can be stacked together: identical
            # shape, DTensor placements (None for plain tensors) and dtype.
            buckets = defaultdict(list)
            for p in trainable:
                placements = p.placements if isinstance(p, DTensor) else None
                buckets[(p.shape, placements, p.dtype)].append(p)

            for bucket in buckets.values():
                grads = [p.grad for p in bucket]
                bucket_states = [
                    self._get_or_initialize_state(p, "aro") for p in bucket
                ]
                momentum_buffers = [s["momentum"] for s in bucket_states]
                rotation_mats = [s["rotation"] for s in bucket_states]

                is_batch_sharded, is_matrix_sharded, sharded_tensor_dim = (
                    self._get_shard_info(bucket[0], group)
                )

                # When only the batch dimension is sharded, the process
                # group is withheld from the megabatch update.
                task_kwargs = (
                    {**shared_kwargs, "process_group": None}
                    if is_batch_sharded and not is_matrix_sharded
                    else shared_kwargs
                )

                yield AsyncTask(
                    aro_update_megabatch_async(
                        X=bucket,
                        G=grads,
                        M=momentum_buffers,
                        R=rotation_mats,
                        shard_dim=sharded_tensor_dim,
                        **task_kwargs,
                    )
                )


def aro_update_megabatch_async(
    X: List[Tensor],
    G: List[Tensor],
    M: List[Tensor],
    R: List[Tensor],  # float32 rotation matrices
    lr: Tensor,
    momentum: Tensor,
    weight_decay: Tensor,
    epsilon: Tensor,
    base_opt: str,
    flatten: bool,
    adjust_lr: Optional[str],
    device_rank: int,
    world_size: int,
    shard_dim: Optional[int] = None,
    process_group: Optional[ProcessGroup] = None,
) -> Generator[None, None, None]:
    """
    Megabatched ARO update for a list of same-shaped parameters.

    Steps, in order:
      1. EMA momentum update of M (in place, on the local tensors).
      2. The ARO step (rotation, cross-alignment, QR, update direction) is
         run through megabatch_orthogonalize_async, which receives the ARO
         step as its ``newton_schulz_func`` callback.
      3. The rotation matrices in R that belong to this rank's slice are
         overwritten in place with the new Q factors.
      4. Weight decay and the parameter update are applied to X.

    For FSDP: the all-to-all reassembles full matrices before the ARO step;
    for DDP: each rank processes its assigned params then all-gathers.
    In both cases, the ARO computation runs inside a closure that captures
    the rotation matrices R for the assigned params.

    Args:
        X: Parameters to update.
        G: Gradients, one per parameter.
        M: Momentum buffers, one per parameter; updated in place.
        R: Rotation matrices, one per parameter; entries assigned to this
            rank are updated in place.
        lr: Base learning rate.
        momentum: EMA factor mu.
        weight_decay: Weight decay factor.
        epsilon: Forwarded to megabatch_orthogonalize_async; the ARO
            callback itself does not use it.
        base_opt: Name of the base optimizer function ("row_norm", "sign"
            or "sinkhorn").
        flatten: Forwarded to the learning-rate adjustment helpers only;
            the megabatch call always passes flatten=False.
        adjust_lr: "spectral_norm", "rms_norm" or None.
        device_rank: Rank of this process, used to pick its slice of R.
        world_size: Number of ranks the batch is split across.
        shard_dim: Sharded tensor dimension, or None if not sharded.
        process_group: Process group for communication, or None.

    Raises:
        ValueError: If adjust_lr is not one of the supported values.
    """
    N = len(X)
    assert N == len(G) == len(M) == len(R)

    M_local = to_local(M)
    G_local = to_local(G)

    # Update momentum: M = mu * M + (1-mu) * G
    # (lerp with weight 1-mu; gradients are first cast to the momentum dtype)
    G_cast = [g.to(dtype=m.dtype) for g, m in zip(G_local, M_local)]
    torch._foreach_lerp_(M_local, G_cast, 1 - momentum)

    base_opt_fn = _get_base_opt_fn(base_opt)
    # Express the sharded dimension as a negative index relative to X[0].ndim.
    comm_dim = (shard_dim - X[0].ndim) if shard_dim is not None else None

    # Determine which params are assigned to this rank so the closure
    # can access the right rotation matrices.
    # NOTE(review): this bookkeeping is meant to mirror the padding done
    # inside megabatch_orthogonalize_async. Confirm both use the same
    # condition, in particular for N == 1 with comm_dim set, where this
    # code does not pad and every rank uses (and later overwrites) R[0].
    pad_n = (world_size - N % world_size) % world_size if process_group is not None and N > 1 else 0
    N_total = N + pad_n
    per_rank = N_total // world_size if process_group is not None and N > 1 else N
    start = device_rank * per_rank if process_group is not None and N > 1 else 0

    # Pad R to match padded M list, stack this rank's assigned rotations
    # (padding entries are identity matrices and are never written back).
    R_padded = R + [torch.eye(
        R[0].shape[-1], device=R[0].device, dtype=R[0].dtype
    )] * pad_n if pad_n > 0 else R
    R_my = torch.stack(R_padded[start : start + per_rank])
    # One-element list so the nested callback can hand the new Q factors
    # back to this scope; stays [None] if the callback is never invoked.
    R_new_holder = [None]

    def aro_ortho_fn(M_batch, epsilon=None):
        """Full ARO step: rotation → base_opt → cross-alignment → QR → update.

        M_batch is [per_rank, m, n] — full (unsharded) matrices for the
        params assigned to this rank, after all-to-all reassembly (FSDP)
        or direct stacking (DDP).

        NOTE(review): R_my is [per_rank, m, m], so the batched matmuls
        below line up only when each parameter is a plain 2-D matrix;
        confirm behavior for parameters with extra leading dimensions.
        """
        # Disable autocast — FSDP MixedPrecisionPolicy may wrap the
        # optimizer step in bf16 autocast, but ARO needs float32 for
        # the rotation and cross-alignment computation.
        with torch.autocast("cuda", enabled=False):
            return _aro_ortho_fn_impl(M_batch, epsilon)

    def _aro_ortho_fn_impl(M_batch, epsilon=None):
        """Compute the new rotation Q and the update direction Q @ f(Qᵀ M).

        ``epsilon`` is accepted to match the callback signature and is not
        used. The new Q is stored in R_new_holder[0] as a side effect.
        The result is cast back to M_batch's dtype.
        """
        M_f32 = M_batch.float()

        # Phase 1: compute cross-alignment matrix
        # cross = M @ f(R_oldᵀ M)ᵀ; intermediates are deleted as soon as
        # they are no longer needed to keep peak memory down.
        rotated = R_my.float().mT @ M_f32
        f_rotated = base_opt_fn(rotated)
        del rotated
        cross = M_f32 @ f_rotated.mT
        del f_rotated, M_f32

        # Phase 2: QR on CPU — cusolver is corrupted by torch.compile on
        # B200/CUDA 13.0 after the first compiled forward/backward.
        # CPU QR is cheap since cross is square [per_rank, m, m].
        Q, _ = torch.linalg.qr(cross.cpu())
        Q = Q.to(device=cross.device, dtype=cross.dtype)
        del cross
        R_new_holder[0] = Q

        # Phase 3: compute update direction with new rotation
        # (M is re-cast to float32 here because the earlier copy was freed
        # before the QR.)
        M_f32 = M_batch.float()
        rotated_new = Q.mT @ M_f32
        f_new = base_opt_fn(rotated_new)
        del rotated_new, M_f32
        return (Q @ f_new).to(M_batch.dtype)

    # Distribute ARO computation via shared megabatch infrastructure.
    # For FSDP: all-to-all reassembles full M, aro_ortho_fn runs on full
    # matrices, result is scattered back.
    # For DDP: each rank runs aro_ortho_fn on its assigned chunk, all-gather.
    U_list = yield from megabatch_orthogonalize_async(
        M_local,
        comm_dim=comm_dim,
        device_rank=device_rank,
        world_size=world_size,
        process_group=process_group,
        newton_schulz_func=aro_ortho_fn,
        flatten=False,
        epsilon=epsilon,
    )

    # Update rotation state for this rank's assigned params
    # (idx >= N are padding slots with no state to write back).
    if R_new_holder[0] is not None:
        Q_new = R_new_holder[0]
        for i in range(per_rank):
            idx = start + i
            if idx < N:
                R[idx].copy_(Q_new[i])

    # Compute adjusted learning rate
    if adjust_lr is None:
        adjusted_lr = lr
    elif adjust_lr == "spectral_norm":
        adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten)
    elif adjust_lr == "rms_norm":
        adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten)
    else:
        raise ValueError(f"Unknown adjust_lr: {adjust_lr}")

    # Apply weight decay and parameter update
    # (weight decay is scaled by the base lr, the step by the adjusted lr).
    aro_post_update(
        X=to_local(X),
        U=U_list,
        base_lr=lr,
        adjusted_lr=adjusted_lr,
        weight_decay=weight_decay,
    )


def _shifted_cholesky_qr(A: Tensor) -> Tensor:
    """Orthogonalize A via Shifted Cholesky QR, with the factorization on CPU.

    torch.compile on B200/CUDA 13.0 corrupts cusolver state, making all
    GPU linalg operations fail after the first compiled forward/backward.
    The Gram matrix is therefore factorized on CPU, while the matmul and
    the triangular solve run on A's device.

    For A of shape [*, m, n] the Gram matrix AᵀA is [*, n, n]. A shift of
    1e-7 * ||A_i||_F² is added to the diagonal of each matrix in the batch
    for numerical stability. The shift is computed per matrix, so the
    result for one matrix does not depend on the scale of the other
    matrices it happens to be batched with.

    Args:
        A: Matrix or batch of matrices, shape [*, m, n].

    Returns:
        Q with the same shape as A and (approximately) orthonormal
        columns. If the Cholesky factorization fails for any matrix in
        the batch, the whole batch falls back to Householder QR on CPU.
    """
    device = A.device

    # Gram matrix [*, n, n], computed on A's device and moved to CPU.
    gram_cpu = (A.mT @ A).cpu()

    # Per-matrix shift proportional to the squared Frobenius norm.
    shift = A.square().sum(dim=(-2, -1)).mul(1e-7)
    shift = shift.to(device="cpu", dtype=gram_cpu.dtype)
    # diagonal() is a view, so the in-place add updates gram_cpu itself;
    # unsqueeze broadcasts one shift value along each matrix's diagonal.
    gram_cpu.diagonal(dim1=-2, dim2=-1).add_(shift.unsqueeze(-1))

    # Cholesky on CPU (avoids cusolver entirely)
    R, info = torch.linalg.cholesky_ex(gram_cpu, upper=True)
    if (info != 0).any():
        # Fallback: QR on CPU
        Q, _ = torch.linalg.qr(A.cpu())
        return Q.to(device)

    # Q = A R⁻¹, obtained by solving Q R = A for Q (left=False).
    return torch.linalg.solve_triangular(
        R.to(device), A, upper=True, left=False
    )


def _get_base_opt_fn(base_opt: str):
    """Map a base optimizer name to the function that implements it.

    Raises:
        ValueError: If ``base_opt`` is not "row_norm", "sign" or "sinkhorn".
    """
    if base_opt == "sinkhorn":
        return _base_opt_sinkhorn
    if base_opt == "sign":
        return _base_opt_sign
    if base_opt == "row_norm":
        return _base_opt_row_norm
    raise ValueError(f"Unknown base_opt: {base_opt}")


def _base_opt_row_norm(X: Tensor, eps: float = 1e-8) -> Tensor:
    """f(X) = sqrt(n) * X / ||x_i||_row

    Each row of X (last dimension, length n) is rescaled to have L2 norm
    sqrt(n).

    Args:
        X: Input tensor; rows are taken along the last dimension.
        eps: Lower bound applied to the row norms before dividing, so that
            all-zero rows map to zero instead of NaN.

    Returns:
        Tensor of the same shape as X.
    """
    n = X.shape[-1]
    row_norms = X.norm(dim=-1, keepdim=True).clamp(min=eps)
    return math.sqrt(n) * X / row_norms


def _base_opt_sign(X: Tensor) -> Tensor:
    """f(X) = sign(X), applied elementwise."""
    signs = X.sign()
    return signs


def _base_opt_sinkhorn(X: Tensor, num_iters: int = 5, eps: float = 1e-8) -> Tensor:
    """SR-Sinkhorn normalization: alternating L2 row/column normalization.

    One iteration rescales every row to L2 norm sqrt(cols) and then every
    column to L2 norm sqrt(rows). These are the square-root iterates of
    the classical Sinkhorn algorithm applied to the matrix of squared
    entries. Norms are clamped below by ``eps`` before dividing.

    Reference: https://arxiv.org/abs/2502.06742
    """
    num_rows = X.shape[-2]
    num_cols = X.shape[-1]
    row_target = math.sqrt(num_cols)
    col_target = math.sqrt(num_rows)

    result = X
    for _step in range(num_iters):
        # Rows first: bring each row to L2 norm sqrt(num_cols).
        per_row = result.norm(dim=-1, keepdim=True).clamp(min=eps)
        result = result * (row_target / per_row)
        # Then columns: bring each column to L2 norm sqrt(num_rows).
        per_col = result.norm(dim=-2, keepdim=True).clamp(min=eps)
        result = result * (col_target / per_col)
    return result


@torch.compile(fullgraph=True)
def aro_post_update(
    X: List[Tensor],
    U: List[Tensor],
    base_lr: Tensor,
    adjusted_lr: Tensor,
    weight_decay: Tensor,
):
    """Apply weight decay and the ARO step to the parameters in place.

    Each X[i] is first multiplied by (1 - base_lr * weight_decay), then
    -adjusted_lr * U[i] is added. Updates are cast to the dtype of X[0]
    before scaling; the scaling is done in place on the cast tensors.
    """
    decay_factor = 1 - base_lr * weight_decay
    torch._foreach_mul_(X, decay_factor)

    param_dtype = X[0].dtype
    steps = [update.to(dtype=param_dtype) for update in U]
    torch._foreach_mul_(steps, -adjusted_lr)
    torch._foreach_add_(X, steps)
2 changes: 1 addition & 1 deletion dion/megabatch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def megabatch_orthogonalize_async(
N = len(U)

# Pad to divisible by world_size (needed by both distributed paths)
if process_group is not None and N > 1:
if process_group is not None and (N > 1 or comm_dim is not None):
pad_n = (world_size - N % world_size) % world_size
U_work = U + [torch.zeros_like(U[0])] * pad_n if pad_n > 0 else U
N_total = len(U_work)
Expand Down