39 changes: 21 additions & 18 deletions src/cooper/multipliers/multipliers.py
@@ -137,6 +137,23 @@ def forward(self) -> torch.Tensor:
class IndexedMultiplier(ExplicitMultiplier):
r""":py:class:`~cooper.multipliers.ExplicitMultiplier` for indexed constraints which
are evaluated only for a subset of constraints on every optimization step.

Args:
num_constraints: Number of constraints associated with the multiplier.
init: Tensor used to initialize the multiplier values. If both ``init`` and
``num_constraints`` are provided, ``init`` must have shape ``(num_constraints,)``.
device: Device for the multiplier. If ``None``, the device is inferred from
``init`` when provided, and falls back to the default device otherwise.
dtype: Data type for the multiplier. Default is ``torch.float32``.
sparse_grad: Whether to use sparse gradients. Default is ``True``. When set to
``False`` with stateful optimizers (e.g., Adam), the optimizer state is
updated for *all* entries of the multiplier, treating non-sampled indices
as having zero gradient. This may lead to incorrect optimization behavior,
since those entries should not be updated at all.

Note:
The default value of ``sparse_grad=True`` is recommended for stateful optimizers.
Set ``sparse_grad=False`` only when necessary (e.g., when using DDP) and with caution.
"""

expects_constraint_features = True
@@ -147,34 +164,20 @@ def __init__(
init: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
*,
sparse_grad: bool = True,
) -> None:
super().__init__(num_constraints, init, device, dtype)
if self.weight.dim() == 1:
# To use the forward call in F.embedding, we must reshape the weight to be a
# 2-dim tensor
self.weight.data = self.weight.data.unsqueeze(-1)
self.sparse_grad = sparse_grad

def forward(self, indices: torch.Tensor) -> torch.Tensor:
"""Return the current value of the multiplier at the provided indices.

Args:
indices: Indices of the multipliers to return. The shape of ``indices`` must
be ``(num_indices,)``.

Raises:
ValueError: If ``indices`` dtype is not ``torch.long``.
"""
if indices.dtype != torch.long:
# Not allowing for boolean "indices", which are treated as indices by
# torch.nn.functional.embedding and *not* as masks.
raise ValueError("Indices must be of type torch.long.")

# TODO(gallego-posada): Document sparse gradients are expected for stateful
# optimizers (having buffers)
multiplier_values = torch.nn.functional.embedding(indices, self.weight, sparse=True)

# Flatten multiplier values to 1D since Embedding works with 2D tensors.
return torch.flatten(multiplier_values)
return self.weight.gather(0, indices, sparse_grad=self.sparse_grad)


class ImplicitMultiplier(Multiplier):
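Reviewer note: a minimal, self-contained sketch (not from this PR) of the behavior the new `gather`-based forward relies on — with `sparse_grad=True`, autograd produces a sparse COO gradient carrying entries only for the sampled indices, which is what lets stateful optimizers skip state updates for non-sampled entries.

```python
import torch

weight = torch.zeros(5, requires_grad=True)
indices = torch.tensor([0, 2], dtype=torch.long)

# gather with sparse_grad=True yields a sparse gradient on backward.
weight.gather(0, indices, sparse_grad=True).sum().backward()

assert weight.grad.is_sparse
# Only the sampled rows carry gradient entries.
assert weight.grad.coalesce().indices().flatten().tolist() == [0, 2]
```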
10 changes: 7 additions & 3 deletions src/cooper/optim/torch_optimizers/nupi_optimizer.py
@@ -197,6 +197,10 @@ def step(self, closure: Optional[Callable] = None) -> Optional[float]:
if p.grad is None:
continue

if p.grad.ndim != 1:
# TODO(juan43ramirez): Implement support for multidimensional parameters
raise NotImplementedError("nuPI optimizer only supports 1D parameters.")

update_function = self.disambiguate_update_function(p.grad.is_sparse, group["init_type"])
update_function(
param=p,
@@ -301,7 +305,7 @@ def _sparse_nupi_zero_init(
nupi_update_values.add_(detached_error_values.mul(et_coef_values))

if xit_m1_coef_values.ne(0).any():
xi_values = state["xi"].sparse_mask(error)._values()
xi_values = state["xi"][tuple(error_indices)]
nupi_update_values.sub_(xi_values.mul(xit_m1_coef))

nupi_update = torch.sparse_coo_tensor(error_indices, nupi_update_values, size=param.shape)
@@ -404,9 +408,9 @@ def _sparse_nupi_sgd_init(
nupi_update_values.add_(detached_error_values.mul(filtered_Ki_values))

if uses_kp_term:
previous_xi_values = state["xi"].sparse_mask(error)._values()
previous_xi_values = state["xi"][tuple(error_indices)]
proportional_term_contribution = torch.where(
state["needs_error_initialization_mask"].sparse_mask(error)._values(),
state["needs_error_initialization_mask"][tuple(error_indices)],
torch.zeros_like(detached_error_values), # If state has not been initialized, xi_0 = 0
(1 - ema_nu) * (detached_error_values - previous_xi_values), # Else, we use recursive update
)
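Reviewer note: a hedged sketch (tensors invented for illustration) of the indexing change above — for a coalesced COO tensor, `dense[tuple(error_indices)]` selects the same values as the replaced `dense.sparse_mask(error)._values()` pattern, without relying on private APIs.

```python
import torch

dense_state = torch.arange(6.0)             # stand-in for state["xi"]
error_indices = torch.tensor([[1, 4]])      # COO indices, shape (ndim, nnz)
error_values = torch.ones(2)
error = torch.sparse_coo_tensor(error_indices, error_values, size=(6,)).coalesce()

# Old pattern (private API) vs. new pattern (plain advanced indexing).
via_sparse_mask = dense_state.sparse_mask(error)._values()
via_indexing = dense_state[tuple(error_indices)]
assert torch.equal(via_sparse_mask, via_indexing)
```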
13 changes: 13 additions & 0 deletions tests/multipliers/conftest.py
@@ -37,6 +37,19 @@ def init_multiplier_tensor(constraint_type, num_constraints, random_seed):
return raw_init


@pytest.fixture(params=[True, False])
def sparse_grad(request):
return request.param


@pytest.fixture
def multiplier(multiplier_class, num_constraints, init_multiplier_tensor, device, sparse_grad):
kwargs = {"num_constraints": num_constraints, "init": init_multiplier_tensor, "device": device}
if multiplier_class == cooper.multipliers.IndexedMultiplier:
kwargs["sparse_grad"] = sparse_grad
return multiplier_class(**kwargs)


@pytest.fixture
def all_indices(num_constraints):
return torch.arange(num_constraints, dtype=torch.long)
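Reviewer note: because `sparse_grad` is a parametrized fixture, every test requesting the `multiplier` fixture is now collected once per parameter value. A self-contained sketch of the fan-out (test name is illustrative, not from the PR):

```python
import pytest

@pytest.fixture(params=[True, False])
def sparse_grad(request):
    return request.param

# Collected twice: test_fan_out[True] and test_fan_out[False].
def test_fan_out(sparse_grad):
    assert isinstance(sparse_grad, bool)
```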
85 changes: 33 additions & 52 deletions tests/multipliers/test_explicit_multipliers.py
@@ -16,14 +16,12 @@ def evaluate_multiplier(multiplier, all_indices):
return multiplier(all_indices) if multiplier.expects_constraint_features else multiplier()


def test_multiplier_initialization_with_init(multiplier_class, init_multiplier_tensor, device):
multiplier = multiplier_class(init=init_multiplier_tensor, device=device)
def test_multiplier_initialization_with_init(multiplier, init_multiplier_tensor, device):
assert torch.equal(multiplier.weight.view(-1), init_multiplier_tensor.to(device).view(-1))
assert multiplier.device.type == device.type


def test_multiplier_initialization_with_num_constraints(multiplier_class, num_constraints, device):
multiplier = multiplier_class(num_constraints=num_constraints, device=device)
def test_multiplier_initialization_with_num_constraints(multiplier, num_constraints, device):
assert multiplier.weight.numel() == num_constraints
assert multiplier.device.type == device.type

@@ -43,8 +41,7 @@ def test_multiplier_initialization_with_init_dim(multiplier_class, num_constrain
multiplier_class(num_constraints=num_constraints, init=torch.zeros(num_constraints, 1))


def test_multiplier_repr(multiplier_class, num_constraints):
multiplier = multiplier_class(num_constraints=num_constraints)
def test_multiplier_repr(multiplier, multiplier_class, num_constraints):
assert repr(multiplier) == f"{multiplier_class.__name__}(num_constraints={num_constraints})"


@@ -53,79 +50,57 @@ def test_multiplier_sanity_check(constraint_type, multiplier_class, init_multipl
if constraint_type == cooper.ConstraintType.EQUALITY:
pytest.skip("")

multiplier = multiplier_class(init=init_multiplier_tensor.abs().neg())
neg_multiplier = multiplier_class(init=init_multiplier_tensor.abs().neg())
with pytest.raises(ValueError, match=r"For inequality constraint, all entries in multiplier must be non-negative."):
multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY)
neg_multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY)


def test_multiplier_init_and_forward(multiplier_class, init_multiplier_tensor, all_indices):
def test_multiplier_init_and_forward(multiplier, init_multiplier_tensor, all_indices):
# Ensure that the multiplier returns the correct value when called
ineq_multiplier = multiplier_class(init=init_multiplier_tensor)
multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices)
multiplier_values = evaluate_multiplier(multiplier, all_indices)
target_tensor = init_multiplier_tensor.reshape(multiplier_values.shape)
assert torch.allclose(multiplier_values, target_tensor)


def test_indexed_multiplier_forward_invalid_indices(init_multiplier_tensor):
multiplier = cooper.multipliers.IndexedMultiplier(init=init_multiplier_tensor)
indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.float32)

with pytest.raises(ValueError, match=r"Indices must be of type torch.long."):
multiplier.forward(indices)


def test_equality_post_step_(constraint_type, multiplier_class, init_multiplier_tensor, all_indices):
def test_equality_post_step_(constraint_type, multiplier, init_multiplier_tensor, all_indices):
"""Post-step for equality multipliers should be a no-op. Check that multiplier
values remain unchanged after calling post_step_.
"""
if constraint_type == cooper.ConstraintType.INEQUALITY:
pytest.skip("")

eq_multiplier = multiplier_class(init=init_multiplier_tensor)
eq_multiplier.set_constraint_type(cooper.ConstraintType.EQUALITY)
eq_multiplier.post_step_()
multiplier_values = evaluate_multiplier(eq_multiplier, all_indices)
multiplier.set_constraint_type(constraint_type)
multiplier.post_step_()
multiplier_values = evaluate_multiplier(multiplier, all_indices)
target_tensor = init_multiplier_tensor.reshape(multiplier_values.shape)
assert torch.allclose(multiplier_values, target_tensor)


def test_ineq_post_step_(constraint_type, multiplier_class, init_multiplier_tensor, all_indices):
def test_ineq_post_step_(constraint_type, multiplier, all_indices):
"""Ensure that the inequality multipliers remain non-negative after post-step."""
if constraint_type == cooper.ConstraintType.EQUALITY:
pytest.skip("")

ineq_multiplier = multiplier_class(init=init_multiplier_tensor)
ineq_multiplier.set_constraint_type(cooper.ConstraintType.INEQUALITY)
multiplier.set_constraint_type(constraint_type)

# Overwrite the multiplier to have some *negative* entries and gradients
hard_coded_weight_data = torch.randn_like(ineq_multiplier.weight)
ineq_multiplier.weight.data = hard_coded_weight_data
hard_coded_weight_data = torch.randn_like(multiplier.weight)
multiplier.weight.data = hard_coded_weight_data

hard_coded_gradient_data = torch.randn_like(ineq_multiplier.weight)
ineq_multiplier.weight.grad = hard_coded_gradient_data
if isinstance(ineq_multiplier, cooper.multipliers.IndexedMultiplier):
ineq_multiplier.weight.grad = ineq_multiplier.weight.grad.to_sparse(sparse_dim=1)
hard_coded_gradient_data = torch.randn_like(multiplier.weight)
if isinstance(multiplier, cooper.multipliers.IndexedMultiplier) and multiplier.sparse_grad:
hard_coded_gradient_data = hard_coded_gradient_data.to_sparse(sparse_dim=1)
multiplier.weight.grad = hard_coded_gradient_data

# Post-step should ensure non-negativity. Note that no feasible indices are passed,
# so "feasible" multipliers and their gradients are not reset.
ineq_multiplier.post_step_()
# Post-step should ensure non-negativity
multiplier.post_step_()
multiplier_values = evaluate_multiplier(multiplier, all_indices)

multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices)
target_weight_data = hard_coded_weight_data.relu()
current_grad = multiplier.weight.grad

target_weight_data = hard_coded_weight_data.relu().reshape_as(multiplier_values)
current_grad = ineq_multiplier.weight.grad.to_dense()
assert torch.allclose(multiplier_values, target_weight_data)
assert torch.allclose(current_grad, hard_coded_gradient_data)

# Perform post-step again, this time with feasible indices
ineq_multiplier.post_step_()

multiplier_values = evaluate_multiplier(ineq_multiplier, all_indices)

current_grad = ineq_multiplier.weight.grad.to_dense()
# Latest post-step is a no-op
assert torch.allclose(multiplier_values, target_weight_data)
assert torch.allclose(current_grad, hard_coded_gradient_data)
assert torch.allclose(current_grad.to_dense(), hard_coded_gradient_data.to_dense())


def check_save_load_state_dict(multiplier, explicit_multiplier_class, num_constraints, random_seed):
@@ -144,7 +119,13 @@ def check_save_load_state_dict(multiplier, explicit_multiplier_class, num_constr
assert torch.equal(multiplier.weight, new_multiplier.weight)


def test_save_load_multiplier(multiplier_class, init_multiplier_tensor, num_constraints, random_seed):
def test_save_load_multiplier(multiplier, multiplier_class, num_constraints, random_seed):
"""Test that the state_dict of a multiplier can be saved and loaded correctly."""
multiplier = multiplier_class(init=init_multiplier_tensor)
check_save_load_state_dict(multiplier, multiplier_class, num_constraints, random_seed)


def test_multiplier_grad(multiplier, all_indices):
evaluate_multiplier(multiplier, all_indices).sum().backward()
assert multiplier.weight.grad.is_sparse == (
isinstance(multiplier, cooper.multipliers.IndexedMultiplier) and multiplier.sparse_grad
)
12 changes: 10 additions & 2 deletions tests/optim/torch_optimizers/test_nupi.py
@@ -158,7 +158,7 @@ def loss_fn(indices):

def compute_analytic_gradient(indices):
# For the quadratic loss, the gradient is simply the current value of p.
return multiplier_module(indices).reshape(-1, 1).clone().detach()
return multiplier_module(indices).clone().detach()

def recursive_nuPI_direction(error, previous_xi):
return (Ki + (1 - ema_nu) * Kp) * error - (1 - ema_nu) * Kp * previous_xi
@@ -275,7 +275,7 @@ def loss_fn(indices):

def compute_analytic_gradient(indices):
# For the quadratic loss, the gradient is simply the current value of p.
return multiplier_module(indices).reshape(-1, 1).clone().detach()
return multiplier_module(indices).clone().detach()

optimizer = nuPI(
multiplier_module.parameters(),
@@ -352,3 +352,11 @@ def do_optimizer_step(indices):
# Check state entries that have not been updated yet
unseen_indices = torch.tensor([4, 8, 9], device=device)
assert torch.allclose(buffer[unseen_indices], torch.zeros_like(buffer[unseen_indices]))


def test_nupi_multi_dimensional_raises():
param = torch.ones(2, 3, requires_grad=True)
param.sum().backward()
optimizer = nuPI([param], lr=0.01)
with pytest.raises(NotImplementedError, match="nuPI optimizer only supports 1D parameters"):
optimizer.step()
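Reviewer note: a hedged end-to-end sketch (not part of this PR; assumes the constructor and forward shown above) of pairing `sparse_grad=True` with a sparse-aware stateful optimizer, so Adam-style buffers are touched only at the sampled rows:

```python
import torch
import cooper

multiplier = cooper.multipliers.IndexedMultiplier(num_constraints=10, sparse_grad=True)
# SparseAdam expects sparse gradients and updates its state only at observed indices.
optimizer = torch.optim.SparseAdam(multiplier.parameters(), lr=1e-2)

indices = torch.tensor([0, 3, 7], dtype=torch.long)
multiplier(indices).sum().backward()   # weight.grad is a sparse tensor
optimizer.step()                       # state updated only for rows 0, 3, 7
```

Conversely, plain Adam with `sparse_grad=False` would update its state for every entry — the failure mode the new docstring warns about.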