From ab4778b41afd5cf0f2382a3e06a6db516e7090ee Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 17:23:26 -0700 Subject: [PATCH 01/18] Add ARO (Adaptively Rotated Optimization) optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the ARO algorithm from https://arxiv.org/abs/2602.09006. ARO maintains a per-parameter rotation matrix R ∈ SO(m) that is updated each step via QR decomposition of a cross-alignment matrix coupling the gradient to the base optimizer's transformation. - Subclasses DistributedOrthoBase for shared infrastructure - R stored in float32 for QR stability, adding O(m²) memory per param - DDP megabatch distributes QR+matmul across ranks via all-gather - Base optimizer functions: row_norm (default) and sign - FSDP not supported (R requires full row dimension) --- dion/__init__.py | 1 + dion/aro.py | 332 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 dion/aro.py diff --git a/dion/__init__.py b/dion/__init__.py index 34894e6..edc79e3 100644 --- a/dion/__init__.py +++ b/dion/__init__.py @@ -1,3 +1,4 @@ +from .aro import ARO from .dion import Dion from .dion import DionMixedPrecisionConfig from .dion_simple import Dion as DionSimple diff --git a/dion/aro.py b/dion/aro.py new file mode 100644 index 0000000..30f24ee --- /dev/null +++ b/dion/aro.py @@ -0,0 +1,332 @@ +import math +import torch +import torch.distributed as dist +from collections import defaultdict +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import ParamsT +from typing import Callable, Generator, List, Optional, Tuple, Union + +from .megabatch_base import DistributedOrthoBase +from .opt_utils import AsyncTask, to_local +from .muon import adjust_lr_spectral_norm, adjust_lr_rms_norm + + +class ARO(DistributedOrthoBase): + """ + Adaptively Rotated Optimization (ARO) optimizer. + + ARO performs normed steepest descent in a rotated coordinate system, + where the rotation is determined by a norm-informed policy that couples + the rotation to the base optimizer's transformation. + + Each parameter of shape [m, n] maintains an m×m rotation matrix R in + float32, adding O(m²) memory per parameter (e.g., 64 MB for m=4096). + + FSDP is not supported — use DDP or single-GPU. + + Reference: https://arxiv.org/abs/2602.09006 + + Args: + params: Parameters for the optimizer. + distributed_mesh: DeviceMesh or ProcessGroup for distributed training. + lr: Base learning rate. + mu: Momentum factor for EMA gradient accumulation. + betas: Tuple of (beta1, beta2) for AdamW and Lion scalar parameter groups. + weight_decay: Weight decay factor. + epsilon: Small value for numerical stability. + base_opt: Base optimizer function applied in the rotated frame. + "row_norm": f(X) = sqrt(n) * X / ||x_i|| (row normalization) + "sign": f(X) = sign(X) + adjust_lr: How to adjust the learning rate for ARO updates. + "spectral_norm", "rms_norm", or None. + flatten: Whether to flatten 3D+ tensors to 2D. + """ + + def __init__( + self, + params: ParamsT, + distributed_mesh: Optional[Union["DeviceMesh", ProcessGroup]] = None, + lr: float = 0.01, + mu: float = 0.95, + betas: Tuple[float, float] = (0.9, 0.95), + weight_decay: float = 0.01, + epsilon: float = 1e-8, + base_opt: str = "row_norm", + adjust_lr: Optional[str] = "rms_norm", + flatten: bool = False, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if mu < 0.0: + raise ValueError(f"Invalid momentum factor: {mu}") + if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: + raise ValueError(f"Invalid betas: {betas}") + if base_opt not in ("row_norm", "sign"): + raise ValueError( + f"Invalid base_opt: {base_opt}. Must be 'row_norm' or 'sign'." + ) + if adjust_lr not in ("spectral_norm", "rms_norm", None): + raise ValueError( + f"Invalid adjust_lr: {adjust_lr}. Must be 'spectral_norm', 'rms_norm', or None." + ) + + defaults = dict( + lr=lr, + mu=mu, + beta1=betas[0], + beta2=betas[1], + weight_decay=weight_decay, + epsilon=epsilon, + base_opt=base_opt, + flatten=flatten, + adjust_lr=adjust_lr, + algorithm="aro", + step=0, + ) + super().__init__(params, distributed_mesh, "aro", defaults) + + def _get_or_initialize_state(self, param: Tensor, algo: str) -> dict: + state = super()._get_or_initialize_state(param, algo) + if algo == "aro" and "rotation" not in state: + m = param.shape[-2] + state["rotation"] = torch.eye(m, device=param.device, dtype=torch.float32) + return state + + def _create_ortho_tasks( + self, param_groups: List[dict] + ) -> Generator["AsyncTask", None, None]: + for group in param_groups: + assert group["algorithm"] == "aro" + assert all( + p.ndim >= 2 for p in group["params"] + ), "ARO only supports matrix parameters." + + group_params = [p for p in group["params"] if p.grad is not None] + if not group_params: + continue + + update_args = dict( + lr=torch.tensor(group["lr"]), + momentum=torch.tensor(group["mu"]), + weight_decay=torch.tensor(group["weight_decay"]), + epsilon=torch.tensor(group["epsilon"]), + base_opt=group["base_opt"], + flatten=group["flatten"], + adjust_lr=group["adjust_lr"], + device_rank=self._device_rank, + world_size=self._world_size, + process_group=self._process_group, + ) + + shape_groups: dict[tuple, list] = defaultdict(list) + for p in group_params: + sharding = p.placements if isinstance(p, DTensor) else None + shape_groups[(p.shape, sharding, p.dtype)].append(p) + + for (_shape, _sharding, _dtype), params in shape_groups.items(): + gradients = [p.grad for p in params] + states = [self._get_or_initialize_state(p, "aro") for p in params] + momentums = [s["momentum"] for s in states] + rotations = [s["rotation"] for s in states] + + is_batch_sharded, is_matrix_sharded, sharded_tensor_dim = ( + self._get_shard_info(params[0], group) + ) + + if is_matrix_sharded: + raise NotImplementedError( + "ARO does not support FSDP-sharded parameters. " + "Use DDP or single-GPU instead." + ) + + megabatch_args = update_args + if is_batch_sharded: + megabatch_args = {**update_args, "process_group": None} + + yield AsyncTask( + aro_update_megabatch_async( + X=params, + G=gradients, + M=momentums, + R=rotations, + **megabatch_args, + ) + ) + + +def aro_update_megabatch_async( + X: List[Tensor], + G: List[Tensor], + M: List[Tensor], + R: List[Tensor], # float32 rotation matrices + lr: Tensor, + momentum: Tensor, + weight_decay: Tensor, + epsilon: Tensor, + base_opt: str, + flatten: bool, + adjust_lr: Optional[str], + device_rank: int, + world_size: int, + process_group: Optional[ProcessGroup] = None, +) -> Generator[None, None, None]: + """ + Megabatched ARO update. Distributes the per-parameter ARO computation + (QR + matmuls) across ranks via all-gather, matching the DDP pattern + used by Muon/NorMuon for Newton-Schulz. + """ + N = len(X) + assert N == len(G) == len(M) == len(R) + + M_local = to_local(M) + G_local = to_local(G) + + # Update momentum: M = mu * M + (1-mu) * G + G_cast = [g.to(dtype=m.dtype) for g, m in zip(G_local, M_local)] + torch._foreach_lerp_(M_local, G_cast, 1 - momentum) + + base_opt_fn = _get_base_opt_fn(base_opt) + + if N > 1 and process_group is not None: + # --- Distributed DDP megabatch --- + pad_n = (world_size - N % world_size) % world_size + if pad_n > 0: + M_work = M_local + [torch.zeros_like(M_local[0])] * pad_n + R_work = R + [torch.eye( + R[0].shape[-1], device=R[0].device, dtype=R[0].dtype + ).expand_as(R[0]).clone() for _ in range(pad_n)] + else: + M_work = M_local + R_work = R + + N_total = len(M_work) + per_rank = N_total // world_size + + start = device_rank * per_rank + my_M = torch.stack(M_work[start : start + per_rank]).float() + my_R = torch.stack(R_work[start : start + per_rank]) + + my_U, my_R_new = _aro_step_batched(my_M, my_R, base_opt_fn) + + # All-gather update directions and new rotations concurrently + all_U = [torch.empty_like(my_U) for _ in range(world_size)] + all_R = [torch.empty_like(my_R_new) for _ in range(world_size)] + work_u = dist.all_gather( + all_U, my_U.contiguous(), group=process_group, async_op=True + ) + work_r = dist.all_gather( + all_R, my_R_new.contiguous(), group=process_group, async_op=True + ) + yield + work_u.wait() + work_r.wait() + + U_list = [all_U[r][i] for r in range(world_size) for i in range(per_rank)][:N] + R_new_list = [all_R[r][i] for r in range(world_size) for i in range(per_rank)][:N] + + elif N == 1: + U, R_new = _aro_step_single(M_local[0].float(), R[0], base_opt_fn) + U_list = [U] + R_new_list = [R_new] + + else: + # N > 1, no process_group (single GPU or batch-sharded) + M_stack = torch.stack(M_local).float() + R_stack = torch.stack(R) + U_stack, R_new_stack = _aro_step_batched(M_stack, R_stack, base_opt_fn) + U_list = [U_stack[i] for i in range(N)] + R_new_list = [R_new_stack[i] for i in range(N)] + + # Update rotation state in-place + for r, r_new in zip(R, R_new_list): + r.copy_(r_new) + + # Compute adjusted learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape, flatten=flatten) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape, flatten=flatten) + else: + raise ValueError(f"Unknown adjust_lr: {adjust_lr}") + + # Apply weight decay and parameter update + aro_post_update( + X=to_local(X), + U=U_list, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + +def _aro_step_single(M: Tensor, R: Tensor, base_opt_fn) -> Tuple[Tensor, Tensor]: + """ARO step for a single parameter. M: [m, n] float32, R: [m, m] float32.""" + # Rotate gradient into R's frame + rotated = R.mT @ M + f_rotated = base_opt_fn(rotated) + + # Cross-alignment → QR for new rotation + cross = M @ f_rotated.mT + Q, _ = torch.linalg.qr(cross) + + # Re-rotate with new R, apply base opt, rotate back + rotated_new = Q.mT @ M + f_new = base_opt_fn(rotated_new) + U = Q @ f_new + + return U, Q + + +def _aro_step_batched(M: Tensor, R: Tensor, base_opt_fn) -> Tuple[Tensor, Tensor]: + """Batched ARO step. M: [N, m, n] float32, R: [N, m, m] float32.""" + rotated = R.mT @ M + f_rotated = base_opt_fn(rotated) + + cross = M @ f_rotated.mT + Q, _ = torch.linalg.qr(cross) + + rotated_new = Q.mT @ M + f_new = base_opt_fn(rotated_new) + U = Q @ f_new + + return U, Q + + +def _get_base_opt_fn(base_opt: str): + if base_opt == "row_norm": + return _base_opt_row_norm + elif base_opt == "sign": + return _base_opt_sign + raise ValueError(f"Unknown base_opt: {base_opt}") + + +def _base_opt_row_norm(X: Tensor) -> Tensor: + """f(X) = sqrt(n) * X / ||x_i||_row""" + n = X.shape[-1] + row_norms = X.norm(dim=-1, keepdim=True).clamp(min=1e-8) + return math.sqrt(n) * X / row_norms + + +def _base_opt_sign(X: Tensor) -> Tensor: + """f(X) = sign(X)""" + return torch.sign(X) + + +@torch.compile(fullgraph=True) +def aro_post_update( + X: List[Tensor], + U: List[Tensor], + base_lr: Tensor, + adjusted_lr: Tensor, + weight_decay: Tensor, +): + """Apply weight decay and parameter update.""" + torch._foreach_mul_(X, 1 - base_lr * weight_decay) + dtype = X[0].dtype + U = [u.to(dtype=dtype) for u in U] + torch._foreach_mul_(U, -adjusted_lr) + torch._foreach_add_(X, U) From 8cd539a12da299f1b3bb994b8d793e28a8a33e52 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 18:11:42 -0700 Subject: [PATCH 02/18] Refactor ARO to reuse megabatch_orthogonalize_async for QR distribution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Instead of reimplementing the pad/stack/assign/all-gather pattern, ARO now plugs QR decomposition into the shared megabatch infrastructure via yield-from, matching how NorMuon and Dion2 plug in Newton-Schulz. Pre-compute (cross-alignment) and post-compute (rotation → update) are done locally; only the QR orthogonalization is distributed. --- dion/aro.py | 126 +++++++++++++++------------------------------------- 1 file changed, 36 insertions(+), 90 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 30f24ee..a4024b0 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -1,6 +1,5 @@ import math import torch -import torch.distributed as dist from collections import defaultdict from torch import Tensor from torch.distributed import ProcessGroup @@ -8,7 +7,7 @@ from torch.optim.optimizer import ParamsT from typing import Callable, Generator, List, Optional, Tuple, Union -from .megabatch_base import DistributedOrthoBase +from .megabatch_base import DistributedOrthoBase, megabatch_orthogonalize_async from .opt_utils import AsyncTask, to_local from .muon import adjust_lr_spectral_norm, adjust_lr_rms_norm @@ -173,9 +172,9 @@ def aro_update_megabatch_async( process_group: Optional[ProcessGroup] = None, ) -> Generator[None, None, None]: """ - Megabatched ARO update. Distributes the per-parameter ARO computation - (QR + matmuls) across ranks via all-gather, matching the DDP pattern - used by Muon/NorMuon for Newton-Schulz. + Megabatched ARO update. Pre-computes cross-alignment matrices locally, + then distributes QR orthogonalization across ranks via the shared + megabatch_orthogonalize_async infrastructure. """ N = len(X) assert N == len(G) == len(M) == len(R) @@ -189,59 +188,39 @@ def aro_update_megabatch_async( base_opt_fn = _get_base_opt_fn(base_opt) - if N > 1 and process_group is not None: - # --- Distributed DDP megabatch --- - pad_n = (world_size - N % world_size) % world_size - if pad_n > 0: - M_work = M_local + [torch.zeros_like(M_local[0])] * pad_n - R_work = R + [torch.eye( - R[0].shape[-1], device=R[0].device, dtype=R[0].dtype - ).expand_as(R[0]).clone() for _ in range(pad_n)] - else: - M_work = M_local - R_work = R - - N_total = len(M_work) - per_rank = N_total // world_size - - start = device_rank * per_rank - my_M = torch.stack(M_work[start : start + per_rank]).float() - my_R = torch.stack(R_work[start : start + per_rank]) - - my_U, my_R_new = _aro_step_batched(my_M, my_R, base_opt_fn) - - # All-gather update directions and new rotations concurrently - all_U = [torch.empty_like(my_U) for _ in range(world_size)] - all_R = [torch.empty_like(my_R_new) for _ in range(world_size)] - work_u = dist.all_gather( - all_U, my_U.contiguous(), group=process_group, async_op=True - ) - work_r = dist.all_gather( - all_R, my_R_new.contiguous(), group=process_group, async_op=True - ) - yield - work_u.wait() - work_r.wait() - - U_list = [all_U[r][i] for r in range(world_size) for i in range(per_rank)][:N] - R_new_list = [all_R[r][i] for r in range(world_size) for i in range(per_rank)][:N] - - elif N == 1: - U, R_new = _aro_step_single(M_local[0].float(), R[0], base_opt_fn) - U_list = [U] - R_new_list = [R_new] - - else: - # N > 1, no process_group (single GPU or batch-sharded) - M_stack = torch.stack(M_local).float() - R_stack = torch.stack(R) - U_stack, R_new_stack = _aro_step_batched(M_stack, R_stack, base_opt_fn) - U_list = [U_stack[i] for i in range(N)] - R_new_list = [R_new_stack[i] for i in range(N)] + # Pre-compute cross-alignment matrices (local, all params) + cross_list = [] + for i in range(N): + M_f32 = M_local[i].float() + rotated = R[i].mT @ M_f32 + f_rotated = base_opt_fn(rotated) + cross_list.append(M_f32 @ f_rotated.mT) + + # Distribute QR across ranks via shared megabatch infrastructure + def qr_orthogonalize(X_in, epsilon=None): + Q, _ = torch.linalg.qr(X_in) + return Q + + Q_list = yield from megabatch_orthogonalize_async( + cross_list, + comm_dim=None, # non-sharded + device_rank=device_rank, + world_size=world_size, + process_group=process_group, + newton_schulz_func=qr_orthogonalize, + flatten=False, + epsilon=epsilon, + ) - # Update rotation state in-place - for r, r_new in zip(R, R_new_list): - r.copy_(r_new) + # Post-compute: use new rotations to produce update directions (local, all params) + U_list = [] + for i in range(N): + Q = Q_list[i] + R[i].copy_(Q) + M_f32 = M_local[i].float() + rotated_new = Q.mT @ M_f32 + f_new = base_opt_fn(rotated_new) + U_list.append(Q @ f_new) # Compute adjusted learning rate if adjust_lr is None: @@ -263,39 +242,6 @@ def aro_update_megabatch_async( ) -def _aro_step_single(M: Tensor, R: Tensor, base_opt_fn) -> Tuple[Tensor, Tensor]: - """ARO step for a single parameter. M: [m, n] float32, R: [m, m] float32.""" - # Rotate gradient into R's frame - rotated = R.mT @ M - f_rotated = base_opt_fn(rotated) - - # Cross-alignment → QR for new rotation - cross = M @ f_rotated.mT - Q, _ = torch.linalg.qr(cross) - - # Re-rotate with new R, apply base opt, rotate back - rotated_new = Q.mT @ M - f_new = base_opt_fn(rotated_new) - U = Q @ f_new - - return U, Q - - -def _aro_step_batched(M: Tensor, R: Tensor, base_opt_fn) -> Tuple[Tensor, Tensor]: - """Batched ARO step. M: [N, m, n] float32, R: [N, m, m] float32.""" - rotated = R.mT @ M - f_rotated = base_opt_fn(rotated) - - cross = M @ f_rotated.mT - Q, _ = torch.linalg.qr(cross) - - rotated_new = Q.mT @ M - f_new = base_opt_fn(rotated_new) - U = Q @ f_new - - return U, Q - - def _get_base_opt_fn(base_opt: str): if base_opt == "row_norm": return _base_opt_row_norm From 09126a2bf5cd95bebba39dcd337ed5b2067ba6b1 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 18:36:06 -0700 Subject: [PATCH 03/18] Add sinkhorn as base_opt for ARO Implements SR-Sinkhorn normalization (alternating L2 row/column normalization), the recommended base optimizer from the ARO paper. Stateless, 5 iterations by default. --- dion/aro.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index a4024b0..261cde9 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -36,6 +36,7 @@ class ARO(DistributedOrthoBase): weight_decay: Weight decay factor. epsilon: Small value for numerical stability. base_opt: Base optimizer function applied in the rotated frame. + "sinkhorn": Alternating row/column L2 normalization (recommended). "row_norm": f(X) = sqrt(n) * X / ||x_i|| (row normalization) "sign": f(X) = sign(X) adjust_lr: How to adjust the learning rate for ARO updates. @@ -62,9 +63,9 @@ def __init__( raise ValueError(f"Invalid momentum factor: {mu}") if len(betas) != 2 or betas[0] < 0.0 or betas[1] < 0.0: raise ValueError(f"Invalid betas: {betas}") - if base_opt not in ("row_norm", "sign"): + if base_opt not in ("row_norm", "sign", "sinkhorn"): raise ValueError( - f"Invalid base_opt: {base_opt}. Must be 'row_norm' or 'sign'." + f"Invalid base_opt: {base_opt}. Must be 'row_norm', 'sign', or 'sinkhorn'." ) if adjust_lr not in ("spectral_norm", "rms_norm", None): raise ValueError( @@ -247,6 +248,8 @@ def _get_base_opt_fn(base_opt: str): return _base_opt_row_norm elif base_opt == "sign": return _base_opt_sign + elif base_opt == "sinkhorn": + return _base_opt_sinkhorn raise ValueError(f"Unknown base_opt: {base_opt}") @@ -262,6 +265,27 @@ def _base_opt_sign(X: Tensor) -> Tensor: return torch.sign(X) +def _base_opt_sinkhorn(X: Tensor, num_iters: int = 5, eps: float = 1e-8) -> Tensor: + """SR-Sinkhorn normalization: alternating L2 row/column normalization. + + Each iteration normalizes rows to have L2 norm sqrt(cols), then + columns to have L2 norm sqrt(rows). This corresponds to the + square-root iterates of the classical Sinkhorn algorithm applied + to the matrix of squared entries. + + Reference: https://arxiv.org/abs/2502.06742 + """ + m, n = X.shape[-2], X.shape[-1] + for _ in range(num_iters): + # Row normalization: each row gets L2 norm sqrt(n) + row_norms = X.norm(dim=-1, keepdim=True).clamp(min=eps) + X = X * (math.sqrt(n) / row_norms) + # Column normalization: each column gets L2 norm sqrt(m) + col_norms = X.norm(dim=-2, keepdim=True).clamp(min=eps) + X = X * (math.sqrt(m) / col_norms) + return X + + @torch.compile(fullgraph=True) def aro_post_update( X: List[Tensor], From 7758f51400ebaf8f835bb0c1425ecb9a24865cd2 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 19:36:23 -0700 Subject: [PATCH 04/18] Default base_opt to sinkhorn (paper's recommendation) --- dion/aro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dion/aro.py b/dion/aro.py index 261cde9..788e4f7 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -53,7 +53,7 @@ def __init__( betas: Tuple[float, float] = (0.9, 0.95), weight_decay: float = 0.01, epsilon: float = 1e-8, - base_opt: str = "row_norm", + base_opt: str = "sinkhorn", adjust_lr: Optional[str] = "rms_norm", flatten: bool = False, ): From 82e9ce024ca4c4b8b522e0d84e531d0bce4bf2d9 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 20:19:59 -0700 Subject: [PATCH 05/18] Add FSDP support to ARO Move the full ARO computation (rotation, cross-alignment, QR, update direction) into a closure passed to megabatch_orthogonalize_async. For FSDP, the all-to-all reassembles full matrices before the closure runs; for DDP, each rank runs the closure on its assigned chunk. The closure captures R for the assigned params, so rotation state stays consistent without a separate all-gather. --- dion/aro.py | 94 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 788e4f7..776050b 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -23,8 +23,6 @@ class ARO(DistributedOrthoBase): Each parameter of shape [m, n] maintains an m×m rotation matrix R in float32, adding O(m²) memory per parameter (e.g., 64 MB for m=4096). - FSDP is not supported — use DDP or single-GPU. - Reference: https://arxiv.org/abs/2602.09006 Args: @@ -135,14 +133,8 @@ def _create_ortho_tasks( self._get_shard_info(params[0], group) ) - if is_matrix_sharded: - raise NotImplementedError( - "ARO does not support FSDP-sharded parameters. " - "Use DDP or single-GPU instead." - ) - megabatch_args = update_args - if is_batch_sharded: + if is_batch_sharded and not is_matrix_sharded: megabatch_args = {**update_args, "process_group": None} yield AsyncTask( @@ -151,6 +143,7 @@ def _create_ortho_tasks( G=gradients, M=momentums, R=rotations, + shard_dim=sharded_tensor_dim, **megabatch_args, ) ) @@ -170,12 +163,18 @@ def aro_update_megabatch_async( adjust_lr: Optional[str], device_rank: int, world_size: int, + shard_dim: Optional[int] = None, process_group: Optional[ProcessGroup] = None, ) -> Generator[None, None, None]: """ - Megabatched ARO update. Pre-computes cross-alignment matrices locally, - then distributes QR orthogonalization across ranks via the shared + Megabatched ARO update. Distributes the full ARO computation (rotation, + cross-alignment, QR, update direction) across ranks via the shared megabatch_orthogonalize_async infrastructure. + + For FSDP: the all-to-all reassembles full matrices before the ARO step; + for DDP: each rank processes its assigned params then all-gathers. + In both cases, the ARO computation runs inside a closure that captures + the rotation matrices R for the assigned params. """ N = len(X) assert N == len(G) == len(M) == len(R) @@ -188,40 +187,61 @@ def aro_update_megabatch_async( torch._foreach_lerp_(M_local, G_cast, 1 - momentum) base_opt_fn = _get_base_opt_fn(base_opt) - - # Pre-compute cross-alignment matrices (local, all params) - cross_list = [] - for i in range(N): - M_f32 = M_local[i].float() - rotated = R[i].mT @ M_f32 + comm_dim = (shard_dim - X[0].ndim) if shard_dim is not None else None + + # Determine which params are assigned to this rank so the closure + # can access the right rotation matrices. + pad_n = (world_size - N % world_size) % world_size if process_group is not None and N > 1 else 0 + N_total = N + pad_n + per_rank = N_total // world_size if process_group is not None and N > 1 else N + start = device_rank * per_rank if process_group is not None and N > 1 else 0 + + # Pad R to match padded M list, stack this rank's assigned rotations + R_padded = R + [torch.eye( + R[0].shape[-1], device=R[0].device, dtype=R[0].dtype + )] * pad_n if pad_n > 0 else R + R_my = torch.stack(R_padded[start : start + per_rank]) + R_new_holder = [None] + + def aro_ortho_fn(M_batch, epsilon=None): + """Full ARO step: rotation → base_opt → cross-alignment → QR → update. + + M_batch is [per_rank, m, n] — full (unsharded) matrices for the + params assigned to this rank, after all-to-all reassembly (FSDP) + or direct stacking (DDP). + """ + M_f32 = M_batch.float() + rotated = R_my.mT @ M_f32 f_rotated = base_opt_fn(rotated) - cross_list.append(M_f32 @ f_rotated.mT) - - # Distribute QR across ranks via shared megabatch infrastructure - def qr_orthogonalize(X_in, epsilon=None): - Q, _ = torch.linalg.qr(X_in) - return Q - - Q_list = yield from megabatch_orthogonalize_async( - cross_list, - comm_dim=None, # non-sharded + cross = M_f32 @ f_rotated.mT + Q, _ = torch.linalg.qr(cross) + R_new_holder[0] = Q + rotated_new = Q.mT @ M_f32 + f_new = base_opt_fn(rotated_new) + return (Q @ f_new).to(M_batch.dtype) + + # Distribute ARO computation via shared megabatch infrastructure. + # For FSDP: all-to-all reassembles full M, aro_ortho_fn runs on full + # matrices, result is scattered back. + # For DDP: each rank runs aro_ortho_fn on its assigned chunk, all-gather. + U_list = yield from megabatch_orthogonalize_async( + M_local, + comm_dim=comm_dim, device_rank=device_rank, world_size=world_size, process_group=process_group, - newton_schulz_func=qr_orthogonalize, + newton_schulz_func=aro_ortho_fn, flatten=False, epsilon=epsilon, ) - # Post-compute: use new rotations to produce update directions (local, all params) - U_list = [] - for i in range(N): - Q = Q_list[i] - R[i].copy_(Q) - M_f32 = M_local[i].float() - rotated_new = Q.mT @ M_f32 - f_new = base_opt_fn(rotated_new) - U_list.append(Q @ f_new) + # Update rotation state for this rank's assigned params + if R_new_holder[0] is not None: + Q_new = R_new_holder[0] + for i in range(per_rank): + idx = start + i + if idx < N: + R[idx].copy_(Q_new[i]) # Compute adjusted learning rate if adjust_lr is None: From bd02502ce0d13fbe9baccc63480b84806ccb84a1 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sat, 28 Mar 2026 21:00:40 -0700 Subject: [PATCH 06/18] Fix cusolver OOM in ARO FSDP path by freeing intermediates before QR Delete rotated, f_rotated, and cross tensors as soon as they're no longer needed, reducing peak memory during the QR decomposition. cusolver's cusolverDnCreate fails with CUSOLVER_STATUS_INTERNAL_ERROR when GPU memory is exhausted by the preceding all-to-all + float32 intermediates. --- dion/aro.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/dion/aro.py b/dion/aro.py index 776050b..3c0756a 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -209,15 +209,29 @@ def aro_ortho_fn(M_batch, epsilon=None): M_batch is [per_rank, m, n] — full (unsharded) matrices for the params assigned to this rank, after all-to-all reassembly (FSDP) or direct stacking (DDP). + + Intermediates are explicitly deleted before QR to reduce peak + memory — cusolver needs workspace and will fail with + CUSOLVER_STATUS_INTERNAL_ERROR if GPU memory is exhausted. """ M_f32 = M_batch.float() + + # Phase 1: compute cross-alignment matrix, then free intermediates rotated = R_my.mT @ M_f32 f_rotated = base_opt_fn(rotated) + del rotated cross = M_f32 @ f_rotated.mT + del f_rotated + + # Phase 2: QR (needs cusolver workspace) Q, _ = torch.linalg.qr(cross) + del cross R_new_holder[0] = Q + + # Phase 3: compute update direction with new rotation rotated_new = Q.mT @ M_f32 f_new = base_opt_fn(rotated_new) + del rotated_new return (Q @ f_new).to(M_batch.dtype) # Distribute ARO computation via shared megabatch infrastructure. From 75b73dca4dec80d2de97a288da7af8db6ae2dad1 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 04:21:59 -0700 Subject: [PATCH 07/18] Fix cusolver OOM: free M_f32 before QR and empty CUDA cache cusolver allocates workspace outside PyTorch's caching allocator, so freed-but-cached tensor blocks aren't visible to it. Two changes: - del M_f32 before QR (recompute it after from M_batch for phase 3) - torch.cuda.empty_cache() before QR to release cached blocks --- dion/aro.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 3c0756a..bdc99c6 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -221,17 +221,20 @@ def aro_ortho_fn(M_batch, epsilon=None): f_rotated = base_opt_fn(rotated) del rotated cross = M_f32 @ f_rotated.mT - del f_rotated + del f_rotated, M_f32 - # Phase 2: QR (needs cusolver workspace) + # Phase 2: QR — cusolver allocates workspace outside PyTorch's + # caching allocator, so we must release cached blocks first. + torch.cuda.empty_cache() Q, _ = torch.linalg.qr(cross) del cross R_new_holder[0] = Q - # Phase 3: compute update direction with new rotation + # Phase 3: recompute M_f32 (freed above to make room for QR) + M_f32 = M_batch.float() rotated_new = Q.mT @ M_f32 f_new = base_opt_fn(rotated_new) - del rotated_new + del rotated_new, M_f32 return (Q @ f_new).to(M_batch.dtype) # Distribute ARO computation via shared megabatch infrastructure. From ee36c410164a7ed25053be2348d0292a0d4054b9 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 06:35:59 -0700 Subject: [PATCH 08/18] Use magma backend for QR to avoid cusolver handle creation failure cusolverDnCreate fails with INTERNAL_ERROR under FSDP memory pressure regardless of available memory. Switch to magma backend for the QR call, restoring the previous backend after. --- dion/aro.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index bdc99c6..888318c 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -223,10 +223,16 @@ def aro_ortho_fn(M_batch, epsilon=None): cross = M_f32 @ f_rotated.mT del f_rotated, M_f32 - # Phase 2: QR — cusolver allocates workspace outside PyTorch's - # caching allocator, so we must release cached blocks first. + # Phase 2: QR — use magma backend to avoid cusolver handle + # creation failures (cusolverDnCreate INTERNAL_ERROR) under + # memory pressure in FSDP configurations. torch.cuda.empty_cache() - Q, _ = torch.linalg.qr(cross) + prev_lib = torch.backends.cuda.preferred_linalg_library() + try: + torch.backends.cuda.preferred_linalg_library("magma") + Q, _ = torch.linalg.qr(cross) + finally: + torch.backends.cuda.preferred_linalg_library(prev_lib) del cross R_new_holder[0] = Q From d0dc3dba3c9eec2f63250579e1ca82585e3ea63d Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 06:56:59 -0700 Subject: [PATCH 09/18] Move QR to CPU to avoid GPU linalg workspace exhaustion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both cusolver and magma fail under FSDP memory pressure because the all-to-all reassembly of full matrices leaves no room for linalg workspace. CPU QR sidesteps this — the cross matrix is square [m, m] so the transfer and computation are cheap relative to the GPU matmuls. The paper's approach (Shifted Cholesky QR / fully distributed rotation) avoids full reassembly entirely but requires a larger architectural change. --- dion/aro.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 888318c..6cbaba0 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -223,16 +223,12 @@ def aro_ortho_fn(M_batch, epsilon=None): cross = M_f32 @ f_rotated.mT del f_rotated, M_f32 - # Phase 2: QR — use magma backend to avoid cusolver handle - # creation failures (cusolverDnCreate INTERNAL_ERROR) under - # memory pressure in FSDP configurations. - torch.cuda.empty_cache() - prev_lib = torch.backends.cuda.preferred_linalg_library() - try: - torch.backends.cuda.preferred_linalg_library("magma") - Q, _ = torch.linalg.qr(cross) - finally: - torch.backends.cuda.preferred_linalg_library(prev_lib) + # Phase 2: QR on CPU — both cusolver and magma fail under FSDP + # memory pressure because the all-to-all reassembly of full + # matrices leaves no room for linalg workspace on GPU. + # CPU QR is cheap since cross is square [per_rank, m, m]. + Q, _ = torch.linalg.qr(cross.cpu()) + Q = Q.to(device=cross.device, dtype=cross.dtype) del cross R_new_holder[0] = Q From 9d9c752d5aa403658eb6785ff37cbf201be9df03 Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 12:19:57 -0700 Subject: [PATCH 10/18] Fix NameError for N=1 shape groups under FSDP The padding block guarded by N > 1 skipped setting per_rank, but the FSDP all-to-all path still referenced it when N=1, causing a NameError. The corrupted async task then left CUDA state dirty, making subsequent QR calls fail with cusolver INTERNAL_ERROR. Fix: also enter the padding block when comm_dim is not None (FSDP), regardless of N. --- dion/megabatch_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dion/megabatch_base.py b/dion/megabatch_base.py index 44f621b..ef0358e 100644 --- a/dion/megabatch_base.py +++ b/dion/megabatch_base.py @@ -260,7 +260,7 @@ def megabatch_orthogonalize_async( N = len(U) # Pad to divisible by world_size (needed by both distributed paths) - if process_group is not None and N > 1: + if process_group is not None and (N > 1 or comm_dim is not None): pad_n = (world_size - N % world_size) % world_size U_work = U + [torch.zeros_like(U[0])] * pad_n if pad_n > 0 else U N_total = len(U_work) From 7644740bab2b53b8245b7e3407f2a89e986f279b Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 13:40:43 -0700 Subject: [PATCH 11/18] Move QR back to GPU now that the root cause (N=1 NameError) is fixed --- dion/aro.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 6cbaba0..2a00031 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -223,12 +223,8 @@ def aro_ortho_fn(M_batch, epsilon=None): cross = M_f32 @ f_rotated.mT del f_rotated, M_f32 - # Phase 2: QR on CPU — both cusolver and magma fail under FSDP - # memory pressure because the all-to-all reassembly of full - # matrices leaves no room for linalg workspace on GPU. - # CPU QR is cheap since cross is square [per_rank, m, m]. - Q, _ = torch.linalg.qr(cross.cpu()) - Q = Q.to(device=cross.device, dtype=cross.dtype) + # Phase 2: QR + Q, _ = torch.linalg.qr(cross) del cross R_new_holder[0] = Q From 6d3a92d1d604e5caf84987bda23dde499bf92c6c Mon Sep 17 00:00:00 2001 From: John Langford Date: Sun, 29 Mar 2026 14:10:31 -0700 Subject: [PATCH 12/18] Replace Householder QR with Shifted Cholesky QR Use matmul + Cholesky + triangular solve instead of torch.linalg.qr. This matches the ARO paper's recommended implementation and dion.py's existing orthogonalize() pattern: G = A^T A + shift*I R = cholesky(G, upper=True) Q = solve_triangular(R, A, upper=True, left=False) Cholesky QR uses far less GPU workspace than Householder QR, avoiding the cusolver/magma OOM issues under FSDP memory pressure. Falls back to Householder QR if the Cholesky factorization fails. --- dion/aro.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 2a00031..c9aa744 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -210,26 +210,23 @@ def aro_ortho_fn(M_batch, epsilon=None): params assigned to this rank, after all-to-all reassembly (FSDP) or direct stacking (DDP). - Intermediates are explicitly deleted before QR to reduce peak - memory — cusolver needs workspace and will fail with - CUSOLVER_STATUS_INTERNAL_ERROR if GPU memory is exhausted. """ M_f32 = M_batch.float() - # Phase 1: compute cross-alignment matrix, then free intermediates + # Phase 1: compute cross-alignment matrix rotated = R_my.mT @ M_f32 f_rotated = base_opt_fn(rotated) del rotated cross = M_f32 @ f_rotated.mT - del f_rotated, M_f32 + del f_rotated - # Phase 2: QR - Q, _ = torch.linalg.qr(cross) + # Phase 2: Shifted Cholesky QR — uses only matmul + Cholesky + + # triangular solve, avoiding Householder QR's large workspace. + Q = _shifted_cholesky_qr(cross) del cross R_new_holder[0] = Q - # Phase 3: recompute M_f32 (freed above to make room for QR) - M_f32 = M_batch.float() + # Phase 3: compute update direction with new rotation rotated_new = Q.mT @ M_f32 f_new = base_opt_fn(rotated_new) del rotated_new, M_f32 @@ -278,6 +275,33 @@ def aro_ortho_fn(M_batch, epsilon=None): ) +def _shifted_cholesky_qr(A: Tensor) -> Tensor: + """Orthogonalize A via Shifted Cholesky QR. + + Uses matmul + Cholesky + triangular solve, which need far less + GPU workspace than Householder QR (torch.linalg.qr). Adds a + small shift to the Gram matrix diagonal for numerical stability. + + If Cholesky fails (input too ill-conditioned), falls back to + Householder QR. + + Same approach as dion.py's orthogonalize() and the ARO paper's + recommended implementation. + """ + G = A.mT @ A # Gram matrix [*, m, m], via cuBLAS + # Shift proportional to the Frobenius norm of A + shift = A.norm() ** 2 * 1e-7 + G.diagonal(dim1=-2, dim2=-1).add_(shift) + # Upper Cholesky: G = R^T R, then Q = A @ R^{-1} + # Same pattern as dion.py's orthogonalize() + R, info = torch.linalg.cholesky_ex(G, upper=True) + if (info != 0).any(): + # Fallback: Householder QR for ill-conditioned inputs + Q, _ = torch.linalg.qr(A) + return Q + return torch.linalg.solve_triangular(R, A, upper=True, left=False) + + def _get_base_opt_fn(base_opt: str): if base_opt == "row_norm": return _base_opt_row_norm From fdacce65c6df94a2e3559a3f44c146056d308929 Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Sun, 29 Mar 2026 16:55:11 -0700 Subject: [PATCH 13/18] Release cached memory before cusolver calls After forward+backward, PyTorch's caching allocator holds ~164GB reserved but only ~2GB allocated. cusolver allocates outside the caching allocator and only sees ~28GB free, causing cusolverDnCreate to fail. torch.cuda.empty_cache() releases cached blocks back to CUDA before the Cholesky/QR calls. Also free M_f32 before the QR (recompute it after for phase 3) to reduce peak memory during decomposition. --- dion/aro.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dion/aro.py b/dion/aro.py index c9aa744..da54e35 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -218,15 +218,19 @@ def aro_ortho_fn(M_batch, epsilon=None): f_rotated = base_opt_fn(rotated) del rotated cross = M_f32 @ f_rotated.mT - del f_rotated + del f_rotated, M_f32 # Phase 2: Shifted Cholesky QR — uses only matmul + Cholesky + # triangular solve, avoiding Householder QR's large workspace. + # Release cached-but-free memory so cusolver can allocate handles + # and workspace (it allocates outside PyTorch's caching allocator). + torch.cuda.empty_cache() Q = _shifted_cholesky_qr(cross) del cross R_new_holder[0] = Q # Phase 3: compute update direction with new rotation + M_f32 = M_batch.float() rotated_new = Q.mT @ M_f32 f_new = base_opt_fn(rotated_new) del rotated_new, M_f32 @@ -297,6 +301,7 @@ def _shifted_cholesky_qr(A: Tensor) -> Tensor: R, info = torch.linalg.cholesky_ex(G, upper=True) if (info != 0).any(): # Fallback: Householder QR for ill-conditioned inputs + torch.cuda.empty_cache() Q, _ = torch.linalg.qr(A) return Q return torch.linalg.solve_triangular(R, A, upper=True, left=False) From 1803abf6df1e5c599e4d3fb32a6cf7873ec5b63c Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Sun, 29 Mar 2026 17:04:43 -0700 Subject: [PATCH 14/18] Move empty_cache inside _shifted_cholesky_qr, right before cholesky_ex The earlier empty_cache before _shifted_cholesky_qr is ineffective because G = A.mT @ A re-fills the caching allocator. The cache must be released after G is computed but before cholesky_ex calls cusolver. --- dion/aro.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dion/aro.py b/dion/aro.py index da54e35..54f784a 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -296,8 +296,10 @@ def _shifted_cholesky_qr(A: Tensor) -> Tensor: # Shift proportional to the Frobenius norm of A shift = A.norm() ** 2 * 1e-7 G.diagonal(dim1=-2, dim2=-1).add_(shift) + # Release cached memory right before cusolver call — the matmul above + # re-fills the cache, so the earlier empty_cache() is not sufficient. + torch.cuda.empty_cache() # Upper Cholesky: G = R^T R, then Q = A @ R^{-1} - # Same pattern as dion.py's orthogonalize() R, info = torch.linalg.cholesky_ex(G, upper=True) if (info != 0).any(): # Fallback: Householder QR for ill-conditioned inputs From d5a31dd3c9c3383ff6fe5d95f5cd4a63b4d5907f Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Sun, 29 Mar 2026 17:17:40 -0700 Subject: [PATCH 15/18] Add synchronize before empty_cache in cusolver path empty_cache cannot release blocks with pending operations on other CUDA streams. The megabatch all-to-all and torch.compile ops may run on separate streams, preventing cache release. synchronize() ensures all pending ops complete before releasing cached blocks. --- dion/aro.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 54f784a..51d9f74 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -296,13 +296,16 @@ def _shifted_cholesky_qr(A: Tensor) -> Tensor: # Shift proportional to the Frobenius norm of A shift = A.norm() ** 2 * 1e-7 G.diagonal(dim1=-2, dim2=-1).add_(shift) - # Release cached memory right before cusolver call — the matmul above - # re-fills the cache, so the earlier empty_cache() is not sufficient. + # Synchronize + release cached memory before cusolver call. + # Without synchronize, empty_cache cannot release blocks with pending + # ops on other CUDA streams (e.g. NCCL all-to-all, torch.compile). + torch.cuda.synchronize() torch.cuda.empty_cache() # Upper Cholesky: G = R^T R, then Q = A @ R^{-1} R, info = torch.linalg.cholesky_ex(G, upper=True) if (info != 0).any(): # Fallback: Householder QR for ill-conditioned inputs + torch.cuda.synchronize() torch.cuda.empty_cache() Q, _ = torch.linalg.qr(A) return Q From 5d04dd34023b7968f85b24f15f2155bf51fba4b4 Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Mon, 30 Mar 2026 07:09:22 -0700 Subject: [PATCH 16/18] Move Cholesky QR to CPU to avoid cusolver corruption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torch.compile on B200/CUDA 13.0 corrupts cusolver state after the first compiled forward/backward pass, making ALL GPU linalg calls fail (even 64x64 identity matrices). This is not memory-related — 190GB free and cusolver still broken. Move the Cholesky decomposition to CPU. The Gram matrix is small ([batch, m, m]) so the CPU overhead is minimal. The matmuls (A^T A and the triangular solve) stay on GPU via cuBLAS which is unaffected. --- dion/aro.py | 45 +++++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index 51d9f74..b5ad18e 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -280,35 +280,32 @@ def aro_ortho_fn(M_batch, epsilon=None): def _shifted_cholesky_qr(A: Tensor) -> Tensor: - """Orthogonalize A via Shifted Cholesky QR. + """Orthogonalize A via Shifted Cholesky QR on CPU. - Uses matmul + Cholesky + triangular solve, which need far less - GPU workspace than Householder QR (torch.linalg.qr). Adds a - small shift to the Gram matrix diagonal for numerical stability. + torch.compile on B200/CUDA 13.0 corrupts cusolver state, making all + GPU linalg operations fail after the first compiled forward/backward. + We move the decomposition to CPU (Gram matrix is small: [batch, m, m]) + and keep the matmuls on GPU. - If Cholesky fails (input too ill-conditioned), falls back to - Householder QR. - - Same approach as dion.py's orthogonalize() and the ARO paper's - recommended implementation. + Uses matmul + Cholesky + triangular solve. Adds a small shift to the + Gram matrix diagonal for numerical stability. Falls back to + Householder QR if Cholesky fails. """ - G = A.mT @ A # Gram matrix [*, m, m], via cuBLAS + device = A.device + G = A.mT @ A # Gram matrix [*, m, m], via cuBLAS (on GPU) + G_cpu = G.cpu() + del G # Shift proportional to the Frobenius norm of A - shift = A.norm() ** 2 * 1e-7 - G.diagonal(dim1=-2, dim2=-1).add_(shift) - # Synchronize + release cached memory before cusolver call. - # Without synchronize, empty_cache cannot release blocks with pending - # ops on other CUDA streams (e.g. NCCL all-to-all, torch.compile). - torch.cuda.synchronize() - torch.cuda.empty_cache() - # Upper Cholesky: G = R^T R, then Q = A @ R^{-1} - R, info = torch.linalg.cholesky_ex(G, upper=True) + shift = A.norm().item() ** 2 * 1e-7 + G_cpu.diagonal(dim1=-2, dim2=-1).add_(shift) + # Cholesky on CPU (avoids cusolver entirely) + R, info = torch.linalg.cholesky_ex(G_cpu, upper=True) if (info != 0).any(): - # Fallback: Householder QR for ill-conditioned inputs - torch.cuda.synchronize() - torch.cuda.empty_cache() - Q, _ = torch.linalg.qr(A) - return Q + # Fallback: QR on CPU + A_cpu = A.cpu() + Q, _ = torch.linalg.qr(A_cpu) + return Q.to(device) + R = R.to(device) return torch.linalg.solve_triangular(R, A, upper=True, left=False) From 8242b92e543a5a7ccb67cd3ac468bd43e5cd0c1f Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Mon, 30 Mar 2026 13:27:22 -0700 Subject: [PATCH 17/18] =?UTF-8?q?Revert=20to=20CPU=20QR=20=E2=80=94=20solv?= =?UTF-8?q?e=5Ftriangular=20also=20uses=20cusolver?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The shifted Cholesky QR moved cholesky to CPU but left solve_triangular on GPU, which still hits the cusolver corruption. Revert to the proven CPU QR approach that ran 569 steps successfully. --- dion/aro.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dion/aro.py b/dion/aro.py index b5ad18e..bfbb973 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -220,12 +220,11 @@ def aro_ortho_fn(M_batch, epsilon=None): cross = M_f32 @ f_rotated.mT del f_rotated, M_f32 - # Phase 2: Shifted Cholesky QR — uses only matmul + Cholesky + - # triangular solve, avoiding Householder QR's large workspace. - # Release cached-but-free memory so cusolver can allocate handles - # and workspace (it allocates outside PyTorch's caching allocator). - torch.cuda.empty_cache() - Q = _shifted_cholesky_qr(cross) + # Phase 2: QR on CPU — cusolver is corrupted by torch.compile on + # B200/CUDA 13.0 after the first compiled forward/backward. + # CPU QR is cheap since cross is square [per_rank, m, m]. + Q, _ = torch.linalg.qr(cross.cpu()) + Q = Q.to(device=cross.device, dtype=cross.dtype) del cross R_new_holder[0] = Q From 0a2b8b965ff7ba4dfc78a6eb6a64bd135a6e2b4a Mon Sep 17 00:00:00 2001 From: JohnLangford Date: Tue, 31 Mar 2026 02:10:35 -0700 Subject: [PATCH 18/18] Disable autocast in aro_ortho_fn and ensure R_my is float32 FSDP MixedPrecisionPolicy wraps the optimizer step in bf16 autocast, causing the R_my.mT @ M_f32 matmul to fail with dtype mismatch. Wrap aro_ortho_fn in torch.autocast("cuda", enabled=False) and explicitly cast R_my to float32. --- dion/aro.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/dion/aro.py b/dion/aro.py index bfbb973..a5288f4 100644 --- a/dion/aro.py +++ b/dion/aro.py @@ -211,10 +211,17 @@ def aro_ortho_fn(M_batch, epsilon=None): or direct stacking (DDP). """ + # Disable autocast — FSDP MixedPrecisionPolicy may wrap the + # optimizer step in bf16 autocast, but ARO needs float32 for + # the rotation and cross-alignment computation. + with torch.autocast("cuda", enabled=False): + return _aro_ortho_fn_impl(M_batch, epsilon) + + def _aro_ortho_fn_impl(M_batch, epsilon=None): M_f32 = M_batch.float() # Phase 1: compute cross-alignment matrix - rotated = R_my.mT @ M_f32 + rotated = R_my.float().mT @ M_f32 f_rotated = base_opt_fn(rotated) del rotated cross = M_f32 @ f_rotated.mT