Bug: Incorrect state initialization in _nupi_sgd_init leads to wrong second step in nuPI #107

@juan43ramirez

Bug

The nuPI optimizer has a bug in its dense initialization function, _nupi_sgd_init, when init_type=nuPIInitType.SGD.

At the end of the first step (t=0), it performs an incorrect state-saving operation, which causes the second optimizer update to be wrong.

Given the state update rule
$$\xi_t = \nu \, \xi_{t-1} + (1 - \nu) \, e_t,$$
the first saved state should be
$$\xi_0 = \nu \, \xi_{-1} + (1 - \nu) \, e_0 = \nu \, e_0 + (1 - \nu) \, e_0 = e_0,$$
where the second equality uses the SGD initialization $\xi_{-1} = e_0$.

Thus, after the first step (t=0), _nupi_sgd_init should store
state["xi"] = e_0 (with e_0 given by detached_error).

Instead, it incorrectly saves
state["xi"] = torch.zeros_like(param),
which forces $\xi_0 = 0$.

On the second step (t=1), the optimizer then computes the proportional term
$$K_P (1-\nu)(e_1 - \xi_0)$$
using this incorrect $\xi_0$. With $\xi_0 = 0$ the term becomes $K_P (1-\nu) \, e_1$ instead of $K_P (1-\nu)(e_1 - e_0)$, producing a disproportionately large and erroneous update.
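
To make the size of the error concrete, here is a minimal numeric sketch of the proportional term on the second step. It uses plain floats rather than tensors, and the helper name `proportional_term` and the values of `kp`, `nu`, `e0`, and `e1` are illustrative assumptions, not taken from the library.

```python
def proportional_term(kp, nu, error, xi_prev):
    """Proportional term K_P * (1 - nu) * (e_t - xi_{t-1}) as written in this issue."""
    return kp * (1.0 - nu) * (error - xi_prev)


# Hypothetical values chosen only for illustration.
kp, nu = 1.0, 0.5
e0, e1 = 1.0, 0.9  # errors at t=0 and t=1

# Expected behavior: xi_0 = e_0, so the term depends on the change in error.
correct = proportional_term(kp, nu, e1, xi_prev=e0)

# Current behavior: xi_0 = 0, so the term depends on the full error e_1.
buggy = proportional_term(kp, nu, e1, xi_prev=0.0)

print(f"correct: {correct:.4f}, buggy: {buggy:.4f}")
```

With these assumed values the buggy term is roughly nine times larger in magnitude and has the opposite sign, since it scales with $e_1$ rather than with $e_1 - e_0$.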

Steps

  1. Initialize the nuPI optimizer with init_type=nuPIInitType.SGD (the default) on a dense parameter.
  2. Run optimizer.step() (t=0). The step update is correct, but the state xi_0 is incorrectly saved as 0.
  3. Run optimizer.step() (t=1).
  4. Observe that the update computed in step 3 (t=1) is wrong, as the proportional term is based on a bad xi_0 state.

Expected behavior

The _nupi_sgd_init function must save state["xi"] = detached_error.clone() at the end of the first step (t=0) to correctly set $\xi_0 = e_0$.

Context

The bug is in this block of _nupi_sgd_init:

30b158e/.../nupi_optimizer.py#L360-L362

Proposed Fix

    if uses_kp_term:
        if "xi" not in state:
            # This is step t=0. Initialize xi_0 = e_0 as per SGD init math.
            state["xi"] = detached_error.clone()
        else:
            # This is step t > 0. Update xi_t = nu*xi_{t-1} + (1-nu)*e_t
            state["xi"].mul_(ema_nu).add_(detached_error, alpha=1 - ema_nu)
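
The branching logic of the proposed fix can be checked in isolation. The following is a simplified sketch that swaps tensors for plain floats and uses a hypothetical helper `update_xi`; it is not the optimizer's actual code path, only a stand-in for the `state["xi"]` bookkeeping.

```python
def update_xi(state, error, ema_nu):
    """Float-based stand-in for the proposed state["xi"] update."""
    if "xi" not in state:
        # Step t=0: store xi_0 = e_0.
        state["xi"] = error
    else:
        # Step t>0: xi_t = nu * xi_{t-1} + (1 - nu) * e_t.
        state["xi"] = ema_nu * state["xi"] + (1.0 - ema_nu) * error
    return state["xi"]


# Hypothetical values chosen only for illustration.
state = {}
nu = 0.5
e0, e1 = 1.0, 0.9

xi0 = update_xi(state, e0, nu)  # first step stores e_0
xi1 = update_xi(state, e1, nu)  # second step applies the EMA update

# Storing e_0 directly matches one EMA step started from xi_{-1} = e_0.
ema_from_sgd_init = nu * e0 + (1.0 - nu) * e0

print(xi0, xi1, ema_from_sgd_init)
```

Under these assumptions, `xi0` equals `e0` and `xi1` is the average of `e0` and `e1` weighted by `nu`, which agrees with the update rule stated above.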

Note on _sparse_nupi_sgd_init

The sparse implementation (_sparse_nupi_sgd_init) should also be checked to ensure it does not suffer from a similar state initialization error.
