From 28458dbc2917901c2e45236befcbbdf183a49d69 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Mon, 29 Sep 2025 15:55:33 +0900 Subject: [PATCH 1/2] Init res state to 1.0 Initialize the res-related states (exp_avg_res_row, exp_avg_res_col) to 1.0 instead of 0.0 to prevent an abnormal res_approx spike. --- came_pytorch/CAME.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/came_pytorch/CAME.py b/came_pytorch/CAME.py index 280edfd..ca5ad4a 100644 --- a/came_pytorch/CAME.py +++ b/came_pytorch/CAME.py @@ -102,8 +102,8 @@ def step(self, closure=None): grad_shape[:-2] + grad_shape[-1:] ).type_as(grad) - state["exp_avg_res_row"] = torch.zeros(grad_shape[:-1]).type_as(grad) - state["exp_avg_res_col"] = torch.zeros( + state["exp_avg_res_row"] = torch.ones(grad_shape[:-1]).type_as(grad) + state["exp_avg_res_col"] = torch.ones( grad_shape[:-2] + grad_shape[-1:] ).type_as(grad) else: @@ -171,4 +171,4 @@ def step(self, closure=None): update.mul_(group["lr"]) p.data.add_(-update) - return loss \ No newline at end of file + return loss From 362737ed818e6ee1a62bfdcedc218f3eb3bf56bc Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:13:31 +0900 Subject: [PATCH 2/2] Prevent division by zero in r_factor --- came_pytorch/CAME.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/came_pytorch/CAME.py b/came_pytorch/CAME.py index ca5ad4a..4962285 100644 --- a/came_pytorch/CAME.py +++ b/came_pytorch/CAME.py @@ -60,7 +60,7 @@ def _rms(self, tensor): def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col): r_factor = ( - (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)) + (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True).add_(self.param_groups[0]["eps"][1])) .rsqrt_() .unsqueeze(-1) )