f-dangel (Owner) commented Dec 26, 2025

This PR is a large refactoring that aims to reduce code duplication between KFAC, EKFAC, and their inverses, and to remove unnecessary complexity. We will not merge this PR directly. The goal is to discuss the current refactor, incorporate feedback, and then break it down into smaller, submittable PRs.

  • Preparations: Introduce new structured linear operators:
    • Introduces a new class KroneckerProductLinearOperator to represent Kronecker products S_1 ⊗ S_2 ⊗ ... whose Kronecker factors S_i are torch.Tensors. This operator contains the logic for multiplying Kronecker products onto vectors and bundles all einsum calls into a single class, so they no longer leak into (E)KFACLinearOperator. It also implements trace, det, logdet, and frobenius_norm, as well as an .inverse(...) function that accepts arguments specifying the damping strategy and value.
    • Introduces a new class EighLinearOperator representing an eigen-decomposed matrix Q @ diag(lam) @ Q.T where Q can be either a torch.Tensor or a linear operator. It implements the same functions as KroneckerProductLinearOperator.
    • Introduces a new class BlockDiagonalLinearOperator representing block_diag(B_1, B_2, ...), where each B_i is a linear operator (either a Kronecker product or an eigen-decomposed matrix). It implements the same functions as the operators above by calling the corresponding function on each block (e.g. the trace is the sum of the block traces). A rough sketch of these interfaces follows this list.
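To make the intended interface concrete, here is a minimal sketch of KroneckerProductLinearOperator (the other two operators would follow the same pattern). The class name and the trace/inverse functions come from the list above, but the matvec name, the per-factor damping strategy, and the implementation details are assumptions for illustration, not the PR's actual code.

```python
import torch


class KroneckerProductLinearOperator:
    """Sketch: represents S_1 ⊗ S_2 ⊗ ... for square torch.Tensor factors S_i."""

    def __init__(self, *factors: torch.Tensor):
        self._factors = factors

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        """Multiply the Kronecker product onto a vector without materializing it."""
        dims = [S.shape[0] for S in self._factors]
        x = v.reshape(*dims)
        # Apply each factor along its mode (a sequence of mode-i tensor products).
        for i, S in enumerate(self._factors):
            x = torch.tensordot(S, x, dims=([1], [i])).movedim(0, i)
        return x.reshape(-1)

    def trace(self) -> torch.Tensor:
        """tr(S_1 ⊗ S_2 ⊗ ...) = tr(S_1) * tr(S_2) * ..."""
        result = self._factors[0].trace()
        for S in self._factors[1:]:
            result = result * S.trace()
        return result

    def inverse(self, damping: float = 0.0) -> "KroneckerProductLinearOperator":
        """One possible damping strategy: damp and invert each factor separately."""
        inv_factors = [
            torch.linalg.inv(
                S + damping * torch.eye(S.shape[0], dtype=S.dtype, device=S.device)
            )
            for S in self._factors
        ]
        return KroneckerProductLinearOperator(*inv_factors)


# Sanity check against the dense Kronecker product.
A = torch.randn(3, 3, dtype=torch.float64)
B = torch.randn(4, 4, dtype=torch.float64)
op, dense = KroneckerProductLinearOperator(A, B), torch.kron(A, B)
v = torch.randn(12, dtype=torch.float64)
assert torch.allclose(op.matvec(v), dense @ v)
assert torch.allclose(op.trace(), dense.trace())
```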

Each of the above points can become a separate PR. I can start on these as soon as we agree that the overall approach makes sense.

  • Refactoring: New structure of a KFACLinearOperator called K:

    • The constructor signature remains unchanged.
    • The operator's constructor builds three operators that compose as K_op = P @ B @ P.T. P is an operator that converts the parameter space to a canonical form by undoing parameter shuffling and potentially grouping together weight and bias components that are treated jointly. B is a block-diagonal linear operator (BlockDiagonalLinearOperator) whose blocks are Kronecker-factored linear operators (KroneckerProductLinearOperator). This makes it much easier to share code between KFAC and EKFAC, as they only differ in how B is set up. It also simplifies their inversion, because K_op_inv = P @ B_inv @ P.T. P and P.T are separate linear operators that take care of the canonicalization.
    • Removes all state_dict functionality. Users can simply save K_op via torch.save(K._operator) and load it back later. This is simpler than the current solution because we can discard the neural network, loss function, etc. K._operator is a _ChainPyTorchLinearOperator and can be used like a normal linear operator.
    • Always compute the Kronecker factors in the constructor. This removes a lot of caching logic and is fine because, due to the previous point, we no longer need to construct a stateless KFACLinearOperator and populate it with a state_dict.
    • Adds a function K.inverse(...) which returns a _ChainPyTorchLinearOperator representing P @ B_inv @ P.T. It accepts arguments to specify the damping value and strategy. As a result, we can remove KFACInverseLinearOperator.
    • Side note: Makes computing the Kronecker factors more 'functional' by removing class dictionaries like self._input_covariances. This massively reduces dependencies across class methods.
  • Refactoring: New structure of an EKFACLinearOperator called E:

    • Compared to KFAC, the operator B differs in that its blocks are EighLinearOperators rather than KroneckerProductLinearOperators. The eigenvalue correction requires rotating the gradients into the Kronecker-factored eigenbasis, and this is done purely with functionality from the base operators above, removing a lot of einsum calls.
    • Saving, loading, and inverting work in exactly the same way as for KFAC; see the sketch below for the shared P @ B @ P.T structure.
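To make the proposed structure concrete, here is a small self-contained numerical sketch of K_op = P @ B @ P.T and its damped inverse, using dense tensors as stand-ins for the structured operators. The permutation used for P, the per-block construction, and the damping strategy are assumptions for illustration, not the PR's implementation.

```python
import torch

torch.manual_seed(0)

# B: block-diagonal curvature approximation, one Kronecker-factored block per layer.
A1, G1 = torch.randn(2, 2).double(), torch.randn(3, 3).double()
A2, G2 = torch.randn(4, 4).double(), torch.randn(5, 5).double()
block1 = torch.kron(A1 @ A1.T, G1 @ G1.T)  # symmetric PSD blocks
block2 = torch.kron(A2 @ A2.T, G2 @ G2.T)
B = torch.block_diag(block1, block2)
dim = B.shape[0]

# P: canonicalization of the parameter space, here just a permutation matrix.
P = torch.eye(dim, dtype=torch.float64)[torch.randperm(dim)]

# K = P @ B @ P.T; since P is orthogonal, the damped inverse has the same shape:
# (K + damping * I)^{-1} = P @ (B + damping * I)^{-1} @ P.T.
damping = 1e-3
K = P @ B @ P.T
B_inv = torch.linalg.inv(B + damping * torch.eye(dim, dtype=torch.float64))
K_inv = P @ B_inv @ P.T

# The damped inverse undoes the damped operator.
v = torch.randn(dim, dtype=torch.float64)
damped_K = K + damping * torch.eye(dim, dtype=torch.float64)
assert torch.allclose(K_inv @ (damped_K @ v), v)

# Because the whole operator is a plain composition of tensors/operators, it can be
# saved and re-loaded without the network or loss function, e.g. torch.save((P, B), "kfac.pt").
```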

We need to discuss how best to split these refactorings into manageable PRs.

  • Miscellaneous:
    • Let K be a KFACLinearOperator. Instead of accessing K.det (or .logdet, .trace, .frobenius_norm) as properties, we now call K.det() etc. Why? Because this makes K feel more like a torch.Tensor (see the snippet below).
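For comparison, torch.Tensor already exposes these reductions as methods rather than properties, which is the behavior the new API mimics:

```python
import torch

M = torch.eye(3, dtype=torch.float64)
# On torch.Tensor, these are method calls — the proposed K.det(), K.trace(),
# and K.logdet() would read the same way.
print(M.det(), M.trace(), M.logdet())  # determinant 1, trace 3, log-determinant 0
```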

f-dangel (Owner, Author) commented Dec 26, 2025

@runame I probably need to clarify more aspects. Let me know if something is weird. Tests are passing locally for me.

runame added the enhancement label Dec 29, 2025

Successfully merging this pull request may close the following issue: [REF] Create shared parent classes for KFACLinearOperator and KFACInverseLinearOperator