Description
Bug
The nuPI optimizer has a bug in its dense initialization function, _nupi_sgd_init, when init_type=nuPIInitType.SGD.
At the end of the first step (t=0), it performs an incorrect state-saving operation, which causes the second optimizer update to be wrong.
Given the state update rule

xi_t = nu * xi_{t-1} + (1 - nu) * e_t,

the first saved state should be

xi_0 = e_0.

Thus, after the first step (t=0), _nupi_sgd_init should store

state["xi"] = e_0 (with e_0 given by detached_error).
Instead, it incorrectly saves

state["xi"] = torch.zeros_like(param),

which forces xi_0 = 0.

On the second step (t=1), the optimizer then computes the proportional term using this incorrect xi_0 = 0 rather than the correct xi_0 = e_0.
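To make the effect concrete, here is a plain-Python sketch of the xi recursion (floats stand in for tensors; the values of ema_nu, e0, and e1 are illustrative, not taken from the issue). It compares xi_1 under the buggy zero initialization and the correct xi_0 = e_0 initialization:

```python
# Sketch of the recursion xi_t = nu * xi_{t-1} + (1 - nu) * e_t with plain floats.
# Illustrative values; the real optimizer operates on tensors.
ema_nu = 0.5
e0, e1 = 1.0, 2.0  # detached errors at t=0 and t=1

# Buggy init: xi_0 is saved as 0 (torch.zeros_like), so e_0 never enters the EMA.
xi0_buggy = 0.0
xi1_buggy = ema_nu * xi0_buggy + (1 - ema_nu) * e1  # 1.0

# Correct init: xi_0 = e_0.
xi0_correct = e0
xi1_correct = ema_nu * xi0_correct + (1 - ema_nu) * e1  # 1.5

print(xi1_buggy, xi1_correct)
```

At t=1 the two initializations already yield different xi states, so any proportional term built from xi is wrong under the buggy initialization.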
Steps
1. Initialize the nuPI optimizer with init_type=nuPIInitType.SGD (the default) on a dense parameter.
2. Run optimizer.step() (t=0). The step update is correct, but the state xi_0 is incorrectly saved as 0.
3. Run optimizer.step() (t=1).
4. Observe that the update computed in step 3 (t=1) is wrong, as the proportional term is based on a bad xi_0 state.
Expected behavior
The _nupi_sgd_init function must save state["xi"] = detached_error.clone() at the end of the first step (t=0) to correctly set xi_0 = e_0.
Context
The bug is in this block of _nupi_sgd_init:
30b158e/.../nupi_optimizer.py#L360-L362
Proposed Fix
```python
if uses_kp_term:
    if "xi" not in state:
        # This is step t=0. Initialize xi_0 = e_0 as per the SGD init math.
        state["xi"] = detached_error.clone()
    else:
        # This is step t > 0. Update xi_t = nu * xi_{t-1} + (1 - nu) * e_t
        state["xi"].mul_(ema_nu).add_(detached_error, alpha=1 - ema_nu)
```
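As a sanity check, the branch above can be exercised with plain floats in place of tensors. Here update_xi is a hypothetical stand-alone mirror of the fix, not a function in the codebase:

```python
def update_xi(state, detached_error, ema_nu):
    # Mirrors the proposed fix, with floats standing in for tensors.
    if "xi" not in state:
        # t = 0: initialize xi_0 = e_0
        state["xi"] = detached_error
    else:
        # t > 0: xi_t = nu * xi_{t-1} + (1 - nu) * e_t
        state["xi"] = ema_nu * state["xi"] + (1 - ema_nu) * detached_error

state = {}
update_xi(state, 2.0, 0.5)  # t=0: xi = 2.0
update_xi(state, 4.0, 0.5)  # t=1: xi = 0.5 * 2.0 + 0.5 * 4.0 = 3.0
```

With the buggy initialization, the t=0 call would instead leave xi = 0, and the t=1 call would produce 2.0 rather than 3.0.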
Note on _sparse_nupi_sgd_init
The sparse implementation (_sparse_nupi_sgd_init) should also be checked to ensure it does not suffer from a similar state initialization error.