PML-379: Fix .cuda(), .cpu(), .float() for memristor #278

Open
LF-Vigneux wants to merge 10 commits into merlinquantum:release/0.4.1 from LF-Vigneux:PML-379-Fix-.device-for-memristor

@LF-Vigneux

Contributor

Summary

Issue raised by Claude on review of release 0.4.0

Memristive state not moved by .cuda()/.cpu()/.float() (device/dtype split)
QuantumLayer overrides to() to move memristive_state/memristive_history, but nn.Module.cuda()/.cpu()/.float()/.double()/.half() route through _apply, not to(). Since the memristive tensors are plain Python lists (not registered buffers), those move paths leave them on the old device/dtype, and the next forward() mixes devices → runtime error or silent wrong-device compute.
Fix: override _apply (or register the state/history as buffers) so all torch move paths convert the memristive tensors.
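
The dispatch problem can be reproduced without PyTorch. The following is a minimal, torch-free sketch under stated assumptions: ToyModule is a hypothetical stand-in for nn.Module in which to(), cuda() and double() all funnel through _apply, and "tensors" are modeled as (device, dtype) tuples. None of these toy classes exist in Merlin or PyTorch; they only illustrate why state held in a plain list is missed unless _apply itself is overridden.

```python
# Toy stand-in for nn.Module: every move method funnels through _apply.
# "Tensors" are (device, dtype) tuples. All class names are hypothetical.


class ToyModule:
    def __init__(self):
        self.registered = {"weight": ("cpu", "float32")}  # like params/buffers

    def _apply(self, fn):
        self.registered = {k: fn(t) for k, t in self.registered.items()}
        return self

    def to(self, device):
        return self._apply(lambda t: (device, t[1]))

    def cuda(self):
        return self._apply(lambda t: ("cuda", t[1]))

    def double(self):
        return self._apply(lambda t: (t[0], "float64"))


class OnlyOverridesTo(ToyModule):
    """Mirrors the bug: state in a plain list, moved only by to()."""

    def __init__(self):
        super().__init__()
        self.state = [("cpu", "float32")]  # plain list, invisible to _apply

    def to(self, device):
        super().to(device)
        self.state = [(device, t[1]) for t in self.state]
        return self


class OverridesApply(ToyModule):
    """Mirrors the fix: the plain-list state is converted inside _apply."""

    def __init__(self):
        super().__init__()
        self.state = [("cpu", "float32")]

    def _apply(self, fn):
        super()._apply(fn)
        self.state = [fn(t) for t in self.state]
        return self


buggy = OnlyOverridesTo().cuda()
fixed = OverridesApply().cuda()
print(buggy.registered["weight"], buggy.state[0])  # devices disagree
print(fixed.registered["weight"], fixed.state[0])  # devices agree
```

In the toy, OverridesApply needs no to() override at all, because to() already reaches the overridden _apply.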

Related Issue

PML-379

Type of change

  • Bug fix
  • New feature
  • Documentation update
  • Refactor / Cleanup
  • Performance improvement
  • CI / Build / Tooling
  • Breaking change (requires migration notes)

Proposed changes

  • Added an _apply override that does the same work as to(): auxiliary modules are moved with to(), and everything else goes through _apply as PyTorch intends.

How to test / How to run

pytest -q

Documentation

  • User docs updated (Sphinx)
  • Examples / notebooks updated
  • Docstrings updated
  • Updated the API

@LF-Vigneux

Contributor Author

@claude review

@claude

claude Bot commented Jun 25, 2026


Claude finished @LF-Vigneux's task in 2m 50s


PR Review: PML-379 Fix .cuda(), .cpu(), .float() for memristor

  • Gather context (diff, base branch)
  • Review _apply implementation in layer.py
  • Review tests in test_layer.py
  • Post review feedback

The overall approach is correct: overriding _apply is the right fix, since nn.Module.cuda()/.cpu()/.float()/.double()/.half() all route through _apply (not to), and the memristive tensors live in plain Python lists. The test coverage is also genuinely good. But there are a few correctness issues and a notable redundancy worth addressing before merge.


🔴 Bugs / correctness

1. _apply returns None when the module has no parameters/buffers (merlin/algorithms/layer.py:1634)

except StopIteration:
    return   # ← returns None, not self

nn.Module.to(), .cuda(), .cpu() etc. all do return self._apply(...). If a QuantumLayer ever has neither parameters nor buffers, layer.cuda() would return None, silently breaking x = model.cuda() chaining. Should be return self. (Edge case, but a clear correctness defect.)
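
A short, torch-free sketch of why the bare return matters. Base, ReturnsNone and ReturnsSelf are hypothetical toy classes that only model the early-return shape of the bug, with a cuda() method that returns whatever _apply returns, as the public move methods do.

```python
# Hypothetical toy classes; they model only the early-return shape of the bug.


class Base:
    def __init__(self, tensors):
        self.tensors = list(tensors)

    def cuda(self):
        # The public method returns whatever _apply returns.
        return self._apply(lambda t: "cuda")


class ReturnsNone(Base):
    def _apply(self, fn):
        try:
            next(iter(self.tensors))
        except StopIteration:
            return  # bug: implicit None when there is nothing to move
        self.tensors = [fn(t) for t in self.tensors]
        return self


class ReturnsSelf(Base):
    def _apply(self, fn):
        try:
            next(iter(self.tensors))
        except StopIteration:
            return self  # fix: chaining keeps working
        self.tensors = [fn(t) for t in self.tensors]
        return self


broken = ReturnsNone([]).cuda()
ok = ReturnsSelf([]).cuda()
print(broken is None, isinstance(ok, ReturnsSelf))
```

With the buggy variant, a caller writing model = model.cuda() would silently end up holding None.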

2. self.dtype is unconditionally recomputed from ref_tensor.dtype, which breaks .half() and complex first tensors (merlin/algorithms/layer.py:1638)

self.device = ref_tensor.device
_, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
    None, ref_tensor.dtype,
)

setup_device_and_dtype (module.py:103) raises ValueError for any dtype other than float32/float64. Two consequences:

  • .half() now raises instead of working. The new docstring explicitly lists half() as supported, but super()._apply(fn) converts the first parameter to float16, so setup_device_and_dtype(None, torch.float16) throws. So the docstring is inaccurate — .half() is not actually supported by this implementation.
  • Device-only moves can crash. On .cpu()/.cuda(), dtype shouldn't change at all, but the code still feeds ref_tensor.dtype into setup_device_and_dtype. If the first registered parameter or buffer happens to be complex (e.g. a unitary buffer), .cpu()/.cuda() will raise ValueError. The current tests pass only because the first tensor in the test fixture is real float32.

Recommend only updating self.dtype/self.complex_dtype when ref_tensor is a real floating dtype (or otherwise guarding against complex/half), and leaving them untouched on pure device moves:

self.device = ref_tensor.device
if ref_tensor.dtype in (torch.float32, torch.float64):
    _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
        None, ref_tensor.dtype,
    )


🟡 Design / maintainability

3. The to() override is now largely redundant with _apply — and double-moves everything.

nn.Module.to() is implemented as return self._apply(convert). Since _apply is now overridden, calling layer.to(device, dtype) executes the full memristive-state + auxiliary-module moving once inside _apply (via super().to()), and then again in the body of the to() override. It's not incorrect (re-moving an already-moved tensor is a no-op), but it's wasteful and the two methods will drift over time. Consider deleting the to() override entirely (now that _apply is the single source of truth) or reducing it to argument normalization only. This would also eliminate the dtype-handling inconsistency between the two paths.

4. Heavy code duplication between to() and _apply. The photon-loss / detector / probability-readout / memristive blocks are near-verbatim copies. Extracting a private _move_auxiliary_state(device, dtype, fn) helper would remove ~40 duplicated lines and keep the two move paths from diverging.

5. Behavioral divergence in how memristive tensors are converted. _apply uses fn(tensor) (respects PyTorch's float-only semantics), while to() forces dtype=self.dtype on every memristive tensor. For real tensors these agree, but they would diverge for complex memristive state — another reason to unify on a single path.
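
The double move from finding 3 can be made visible with a counter. This is a torch-free sketch: ToyModule, DoubleMover and SinglePath are hypothetical names, and _move_aux stands in for the photon-loss / detector / readout handling.

```python
# Hypothetical toy classes; _move_aux stands in for the auxiliary-module moves.


class ToyModule:
    def _apply(self, fn):
        return self

    def to(self, device):
        # A thin wrapper around _apply, like nn.Module.to().
        return self._apply(lambda t: device)


class DoubleMover(ToyModule):
    """Overrides both to() and _apply, so to() moves the aux state twice."""

    def __init__(self):
        self.aux_moves = 0

    def _move_aux(self):
        self.aux_moves += 1

    def _apply(self, fn):
        super()._apply(fn)
        self._move_aux()
        return self

    def to(self, device):
        super().to(device)  # already dispatches into the overridden _apply
        self._move_aux()    # second, redundant move
        return self


class SinglePath(ToyModule):
    """Overrides only _apply; to() reaches it through the base class."""

    def __init__(self):
        self.aux_moves = 0

    def _apply(self, fn):
        super()._apply(fn)
        self.aux_moves += 1
        return self


print(DoubleMover().to("cuda").aux_moves)  # 2
print(SinglePath().to("cuda").aux_moves)   # 1
```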


🟢 Minor / robustness

6. Signature mismatch with the parent. def _apply(self, fn): drops the recurse parameter that torch.nn.Module._apply(self, fn, recurse=True) has had since PyTorch 2.0. Any caller using module._apply(fn, recurse=False) would hit a TypeError. Safer to mirror the parent: def _apply(self, fn, recurse=True): and super()._apply(fn, recurse=recurse).

7. Redundant .to() on registered submodules. computation_process / single-module _photon_loss_transform / _detector_transform are auto-registered as submodules, so super()._apply(fn) already moved them; the explicit .to(...) calls only matter for the list (Sequence) variants. Harmless, but the comment could clarify that the manual handling exists for the un-registered list case.


✅ Tests

Good, thorough coverage — parametrized cpu/float/double, CUDA-guarded variants, and a noisy-layer fixture exercising the photon-loss/detector/g2 paths. A few gaps that map directly to the bugs above:

  • No .half() test; one would currently surface the ValueError from finding #2.
  • No test for the no-params/no-buffers path (finding #1).
  • No combined .to(device=..., dtype=...) test, and no assertion that .to() and _apply agree on a layer whose first tensor is complex.
  • assert_layers_equal only checks zip-aligned pairs for parameters/buffers; it does assert key-set equality for named params/buffers (good), but _module_or_sequence_equal relies on zip without a length check on .parameters() — fine in practice, minor.

Nothing here blocks the core fix conceptually, but findings #1 and #2 are real correctness issues (especially the inaccurate .half() docstring claim) and I'd address them before merge.
· PML-379-Fix-.device-for-memristor

@LF-Vigneux

Contributor Author

@claude I have tried to address your 🟡 Design / maintainability concerns. However, device=None is no longer preserved as a None device when the move goes through the _apply function. How would you fix it?

@claude

claude Bot commented Jun 25, 2026


Claude finished @LF-Vigneux's task in 5m 32s


PML-379 — Fixing device=None semantics in _apply

  • Read current to() / _apply implementation
  • Trace where device=None gets lost
  • Implement fix
  • Verify reasoning (couldn't run pytest — see note)

Why device=None stopped passing through

The regression is in this block:

ref_tensor = next(self.parameters(), ...)   # a parameter
self.device = ref_tensor.device             # <-- always concrete!

After super()._apply(fn) runs, every parameter already lives on a concrete device (e.g. cpu). So reading the device back off a parameter can never yield None — a dtype-only conversion like .float()/.double() ends up forcing a device onto everything downstream. That matters because ComputationProcess.to(dtype=…, device=None) treats device=None specially ("keep PyTorch's implicit placement"), and _apply was clobbering that distinction.

The root cause is using a result (the moved parameter) to infer the intent (what the user asked fn to change). A parameter can't tell you "the user only changed dtype, not device."

The fix — probe fn instead of reading a parameter

fn is the conversion closure PyTorch builds; it encodes exactly what was requested. Apply it to a tiny probe tensor placed on the layer's current device, and compare:

old_device = self.device if self.device is not None else torch.device("cpu")
probe = torch.zeros((), device=old_device, dtype=torch.get_default_dtype())
moved = fn(probe)

if moved.device != old_device:      # a real device move was requested
    self.device = moved.device      # ...otherwise leave it (stays None)

if moved.dtype in (torch.float32, torch.float64):   # guards .half()/complex
    _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
        None, moved.dtype
    )

self.computation_process.to(dtype=self.dtype, device=self.device)

Probing on old_device (not bare cpu) is what makes it correct in every direction:

| call | old_device | moved.device | result |
| --- | --- | --- | --- |
| .float() (dtype only) | cpu/None | cpu | device left as-is (None preserved) ✅ |
| .cuda() from cpu | cpu | cuda:0 | self.device = cuda:0 |
| .cpu() from cuda | cuda:0 | cpu | self.device = cpu |
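
The three cases can be traced with a torch-free sketch of the probe idea. FakeTensor and infer_bookkeeping are hypothetical names introduced here for illustration; the real code probes with torch.zeros and fn as shown earlier, and this toy probes with the current dtype rather than the default dtype.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tensor: only device and dtype are modeled."""

    device: str
    dtype: str


def infer_bookkeeping(
    fn: Callable[[FakeTensor], FakeTensor],
    current_device: Optional[str],
    current_dtype: str,
) -> Tuple[Optional[str], str]:
    """Apply fn to a probe on the current device and read off what changed."""
    old_device = current_device if current_device is not None else "cpu"
    moved = fn(FakeTensor(old_device, current_dtype))
    # Only record a device when the probe actually moved; otherwise keep None.
    new_device = moved.device if moved.device != old_device else current_device
    return new_device, moved.dtype


def to_double(t: FakeTensor) -> FakeTensor:
    return replace(t, dtype="float64")


def to_cuda(t: FakeTensor) -> FakeTensor:
    return replace(t, device="cuda:0")


def to_cpu(t: FakeTensor) -> FakeTensor:
    return replace(t, device="cpu")


print(infer_bookkeeping(to_double, None, "float32"))   # (None, 'float64')
print(infer_bookkeeping(to_cuda, None, "float32"))     # ('cuda:0', 'float32')
print(infer_bookkeeping(to_cpu, "cuda:0", "float32"))  # ('cpu', 'float32')
```

One limitation of this scheme is visible in the toy as well: a move to cpu on a layer whose device is None looks identical to a no-op, so the device stays None.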

Two issues from my earlier review fall out for free

  • .half() no longer crashes — the dtype in (float32, float64) guard skips setup_device_and_dtype (which raises ValueError on float16/complex).
  • Auxiliary list-modules now reuse fn directly instead of a reconstructed target_kwargs, so they inherit the same device=None/float-only semantics. Single-module variants are registered submodules and were already moved by super()._apply, so only the Sequence variants need handling:
for aux in (self._photon_loss_transform, self._detector_transform, self._probability_readout):
    if isinstance(aux, Sequence):
        for module in aux:
            module._apply(fn, recurse=recurse)

Bonus: removed the now-redundant to() override

nn.Module.to() is return self._apply(convert). Now that _apply is the single source of truth, the to() override only duplicated logic (and its own arg-parsing would still raise ValueError on .to(torch.float16)). Deleting it removes the double-move (review finding #3) and the divergence (#4/#5), and routes .to() through the same robust path.


⚠️ Couldn't push

I committed the change locally (eef6e17), but the push was rejected — Permission to merlinquantum/merlin.git denied to github-actions[bot] (403). So you'll need to apply it. Full diff below.

Diff for merlin/algorithms/layer.py
@@ class QuantumLayer(MerlinModule):
-    def to(self, *args: Any, **kwargs: Any) -> QuantumLayer:
-        """Move the layer and auxiliary transforms to a new device or dtype.
-        ...
-        """
-        super().to(*args, **kwargs)
-        ...                       # (entire to() override deleted)
-        return self
-
     def _apply(self, fn, recurse=True):
         ...
         super()._apply(fn, recurse=recurse)
         # ... memristive tensors moved with fn(tensor) ...

-        # infer canonical device/dtype AFTER move
-        ref_tensor = None
-        for p in self.parameters():
-            ref_tensor = p
-            break
-        if ref_tensor is None:
-            for b in self.buffers():
-                ref_tensor = b
-                break
-        if ref_tensor is not None:
-            self.device = ref_tensor.device
-            if ref_tensor.dtype in (torch.float32, torch.float64):
-                _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
-                    None, ref_tensor.dtype
-                )
+        old_device = self.device if self.device is not None else torch.device("cpu")
+        probe = torch.zeros((), device=old_device, dtype=torch.get_default_dtype())
+        moved = fn(probe)
+        if moved.device != old_device:
+            self.device = moved.device
+        if moved.dtype in (torch.float32, torch.float64):
+            _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
+                None, moved.dtype
+            )

         self.computation_process.to(dtype=self.dtype, device=self.device)

-        target_kwargs = {"dtype": self.dtype}
-        if self.device is not None:
-            target_kwargs["device"] = self.device
-        if self._photon_loss_transform is not None:
-            ...  # .to(**target_kwargs) for both single + list variants
-        if self._detector_transform is not None:
-            ...
-        if self._probability_readout is not None:
-            self._probability_readout = self._probability_readout.to(device=self.device)
+        for aux in (
+            self._photon_loss_transform,
+            self._detector_transform,
+            self._probability_readout,
+        ):
+            if isinstance(aux, Sequence):
+                for module in aux:
+                    module._apply(fn, recurse=recurse)

         return self

Note: I couldn't run pytest in this environment (sandbox blocks process execution). Please run pytest -q tests/algorithms/test_layer.py after applying, and consider adding a .half() test plus an assertion that layer.to(torch.float64) leaves layer.device unchanged — both now behave correctly with this fix.

@LF-Vigneux

Contributor Author

@claude .half() should fail, no? Is that what was happening with to()? If so, how do I keep that behavior? Also, the added tests are currently failing with AssertionError: assert device(type='cpu') == None. How would you change that?

@claude

claude Bot commented Jun 25, 2026


Claude finished @LF-Vigneux's task in 6m 3s


PML-379 — .half() semantics & the device == None test failure

  • Confirm what the old to() did with float16
  • Decide how to keep .half() failing
  • Diagnose the assert device(type='cpu') == None failure
  • Implement fix in layer.py (+ tests)
  • ⚠️ Push — rejected (403, github-actions[bot] lacks write), diff below for you to apply

1. Should .half() fail? Yes — and here's how to keep it

You're right: Merlin only supports float32/float64 (setup_device_and_dtype in module.py:103 raises ValueError for anything else). With the old to(), layer.to(torch.float16) did raise — because to() called setup_device_and_dtype(None, torch.float16). But .half() went through nn.Module._apply (no override back then), so it silently half-converted. Inconsistent.

The reason .half() stopped failing in your current _apply is this guard:

if moved.dtype in (torch.float32, torch.float64):   # ← silently skips float16

It swallows the unsupported dtype instead of rejecting it. To keep .half() failing, flip the logic: only skip setup_device_and_dtype when the dtype didn't change (pure device move). When the dtype does change, always route it through setup_device_and_dtype so unsupported dtypes raise:

if moved.dtype != old_dtype:          # a real dtype change was requested
    _, new_dtype, new_complex = MerlinModule.setup_device_and_dtype(None, moved.dtype)

I also moved the probe before super()._apply(fn), so .half() raises up front instead of leaving the module half-converted. Now .half() and .to(torch.float16) both raise ValueError — consistent with the old to().
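
The "probe first, mutate second" ordering can be sketched without torch. ToyLayer and validate_dtype are hypothetical; validate_dtype only plays the role that setup_device_and_dtype plays in the text (reject anything other than float32/float64), and dtypes are modeled as strings.

```python
SUPPORTED = {"float32", "float64"}


def validate_dtype(dtype):
    """Hypothetical validator: rejects any dtype outside the supported set."""
    if dtype not in SUPPORTED:
        raise ValueError(f"unsupported dtype: {dtype}")
    return dtype


class ToyLayer:
    def __init__(self):
        self.dtype = "float32"
        self.tensors = ["float32", "float32"]  # dtypes of the owned tensors

    def _apply(self, fn):
        # Probe what fn will do before touching any real state.
        probed = fn(self.dtype)
        new_dtype = self.dtype
        if probed != self.dtype:  # a real dtype change was requested
            new_dtype = validate_dtype(probed)  # raises before any mutation
        self.tensors = [fn(t) for t in self.tensors]
        self.dtype = new_dtype
        return self

    def half(self):
        return self._apply(lambda t: "float16")

    def double(self):
        return self._apply(lambda t: "float64")


layer = ToyLayer()
try:
    layer.half()
except ValueError as exc:
    print("rejected:", exc)
print(layer.tensors)  # unchanged, because validation ran before the move
```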

2. The assert device(type='cpu') == None failure

The failing test is test_apply_methods_match_to[{'device':'cpu'},'cpu']:

expected = deepcopy(layer).to(device="cpu")   # old to(): sets self.device = cpu
actual   = deepcopy(layer).cpu()              # _apply: probe sees cpu→cpu, leaves device = None

The two paths disagreed on bookkeeping. to() explicitly did self.device = torch.device("cpu"), but _apply's probe (correctly, per your earlier device=None requirement) can't distinguish "moved to cpu" from "already cpu / unset", so it left self.device = None. Hence cpu == None fails with AssertionError.

Fix: delete the to() override entirely. nn.Module.to() is literally return self._apply(convert), so once _apply is the single source of truth, .to(...) and .cpu()/.float()/.double() run the exact same code and can't diverge. The test then compares two identical paths and passes. This also resolves review findings #3/#4/#5 (double-move + duplication + divergence).

self.device = None is already a fully supported state everywhere in layer.py (device=None → torch's default cpu placement; guarded at layer.py:871, :956; reset() builds tensors with device=self.device). I verified the existing device tests (test_memristive_layer_to_moves_state_and_history_without_raising, test_memristive_reset_after_to_uses_updated_device_and_dtype) still hold because they assert on the tensor device (cpu), not on self.device.

Note: after this change layer.cpu() leaves layer.device == None (meaning "default/cpu") rather than torch.device("cpu"). The tensors are genuinely on cpu either way. If you'd rather self.device always be concrete, that's a different design choice — but then dtype-only .float() couldn't preserve device=None, which is the property you asked for earlier.


Diff (commit cdf44c9, applied locally — please apply since I can't push)

merlin/algorithms/layer.py — delete the whole to() override, and replace the _apply device/dtype block:

-    def to(self, *args: Any, **kwargs: Any) -> QuantumLayer:
-        """Move the layer and auxiliary transforms to a new device or dtype.
-        ...
-        """
-        super().to(*args, **kwargs)
-        ...                       # (entire to() override deleted)
-        return self
-
     def _apply(self, fn, recurse=True):
         """...
+        Only ``torch.float32`` and ``torch.float64`` are supported dtypes for
+        Merlin modules. Requesting any other floating dtype (e.g. ``.half()`` /
+        ``torch.float16``) raises ``ValueError``.
         ...
         """
+        # Probe what ``fn`` will do *before* mutating any real tensor.
+        old_device = self.device if self.device is not None else torch.device("cpu")
+        old_dtype = self.dtype if self.dtype is not None else torch.get_default_dtype()
+        moved = fn(torch.zeros((), device=old_device, dtype=old_dtype))
+
+        new_device = self.device
+        if moved.device != old_device:          # a real device move was requested
+            new_device = moved.device           # ...otherwise leave it (stays None)
+
+        new_dtype, new_complex = self.dtype, self.complex_dtype
+        if moved.dtype != old_dtype:            # a real dtype change was requested
+            # Raises ValueError for unsupported dtypes (e.g. float16).
+            _, new_dtype, new_complex = MerlinModule.setup_device_and_dtype(
+                None, moved.dtype
+            )
+
         super()._apply(fn, recurse=recurse)

-        # memristive tensors
+        # memristive tensors (plain Python lists, untouched by super()._apply)
         for state in range(len(self.memristive_history)):
             ...                                 # (memristive move unchanged)
         for state in range(len(self.memristive_state)):
             ...

-        # infer canonical device/dtype AFTER move
-        old_device = self.device if self.device is not None else torch.device("cpu")
-        probe = torch.zeros((), device=old_device, dtype=torch.get_default_dtype())
-        moved = fn(probe)
-        if moved.device != old_device:
-            self.device = moved.device
-        if moved.dtype in (torch.float32, torch.float64):
-            _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
-                None, moved.dtype
-            )
+        self.device, self.dtype, self.complex_dtype = new_device, new_dtype, new_complex

         self.computation_process.to(dtype=self.dtype, device=self.device)
         # ... auxiliary-module moving block unchanged ...

tests/algorithms/test_layer.py — added two tests:

def test_half_raises_value_error(layer):
    """Merlin only supports float32/float64, so .half()/float16 must raise."""
    with pytest.raises(ValueError):
        deepcopy(layer).half()
    with pytest.raises(ValueError):
        deepcopy(layer).to(torch.float16)


def test_dtype_only_move_leaves_device_unchanged(layer):
    """A pure dtype move must not assign a concrete device."""
    original_device = layer.device
    moved = deepcopy(layer).double()
    assert moved.device == original_device
    assert moved.dtype == torch.float64

⚠️ Couldn't push / couldn't run pytest

The push was rejected — Permission to merlinquantum/merlin.git denied to github-actions[bot] (403) — so please apply the diff above (it's committed locally as cdf44c9). I also couldn't run pytest (the sandbox blocks process execution). After applying, please run:

pytest -q tests/algorithms/test_layer.py

The previously failing test_apply_methods_match_to cases should now pass (both paths share _apply), and .half() will raise ValueError as intended.

@LF-Vigneux

Contributor Author

@claude Since to() and _apply are now identical, exactly how would you change the tests and implement them to test the complete working?

@claude

claude Bot commented Jun 25, 2026


Claude finished @LF-Vigneux's task in 3m 2s


How to test once to() and _apply() are the same code path

Once you delete the to() override, nn.Module.to() is literally return self._apply(convert). That has a direct consequence for your existing tests:

test_apply_methods_match_to becomes a tautology. It does deepcopy(layer).to(device="cpu") vs deepcopy(layer).cpu() and asserts they're equal — but both now run the exact same _apply. If _apply had a bug (e.g. it moved the memristive history to the wrong device, or forgot a buffer), both sides would carry the identical bug and the test would still pass. A path-equality test can only catch divergence between two paths; when there's only one path, it catches nothing.

So the fix isn't to tweak those tests — it's to change what they assert: stop comparing a move against a sibling move, and start comparing the move against ground truth (the device/dtype every tensor should end up on) plus behavior (a forward pass actually runs after the move). That's what "test the complete working" means here.
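
A torch-free sketch of why the path-equality check goes blind. BuggyLayer is a hypothetical toy whose _apply forgets the plain-list state; to() and cuda() both route through it, so comparing them against each other cannot see the bug, while a ground-truth check can.

```python
class BuggyLayer:
    """Hypothetical toy: _apply converts the weight but forgets self.state."""

    def __init__(self):
        self.weight = "cpu"
        self.state = ["cpu"]  # plain-list state, stored as device names

    def _apply(self, fn):
        self.weight = fn(self.weight)  # bug: self.state is never converted
        return self

    def to(self, device):
        return self._apply(lambda t: device)

    def cuda(self):
        return self._apply(lambda t: "cuda")


def snapshot(layer):
    return (layer.weight, tuple(layer.state))


def path_equality_check():
    # Compares one move against a sibling move that shares the same _apply.
    return snapshot(BuggyLayer().to("cuda")) == snapshot(BuggyLayer().cuda())


def ground_truth_check():
    # Compares the end state against the device every tensor should be on.
    moved = BuggyLayer().cuda()
    return all(device == "cuda" for device in (moved.weight, *moved.state))


print(path_equality_check())  # True: both sides carry the same bug
print(ground_truth_check())   # False: the stale state is caught
```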

  • Read current to() / _apply on the branch
  • Identify why the existing match-tests go weak
  • Design ground-truth + behavioral replacements
  • Post concrete, paste-ready test code

1. Ground-truth helper — assert the end state, not path-equality

This walks every tensor the layer owns, including the memristive state/history that live in plain Python lists (the whole point of the fix), and checks each against an explicit expected device/dtype:

def _iter_layer_tensors(layer):
    """Yield (name, tensor) for every tensor the layer owns — including the
    memristive state/history stored in plain Python lists (not buffers)."""
    yield from layer.named_parameters()
    yield from layer.named_buffers()
    for i, s in enumerate(layer.memristive_state):
        if torch.is_tensor(s):
            yield f"memristive_state[{i}]", s
    for i, hist in enumerate(layer.memristive_history):
        for j, t in enumerate(hist):
            if torch.is_tensor(t):
                yield f"memristive_history[{i}][{j}]", t


def assert_layer_on(layer, *, device, real_dtype):
    """Ground truth: assert the end state of a move directly, with no reference
    to any other code path."""
    dev = torch.device(device)
    seen_memristive = False
    for name, t in _iter_layer_tensors(layer):
        assert t.device.type == dev.type, f"{name}: {t.device} != {dev}"
        if t.is_floating_point():                     # complex left to its own test
            assert t.dtype == real_dtype, f"{name}: {t.dtype} != {real_dtype}"
        if name.startswith("memristive"):
            seen_memristive = True
    # guard against a fixture that silently has no memristive tensors —
    # otherwise this whole helper would vacuously pass
    assert seen_memristive, "fixture has no memristive tensors to check"
    # layer bookkeeping
    assert layer.dtype == real_dtype
    assert layer.computation_process.dtype == real_dtype
    if layer.device is not None:                      # None == 'default/cpu', allowed
        assert layer.device.type == dev.type

The seen_memristive guard matters: a ground-truth helper that iterates an empty list passes vacuously, which would hide a regression where the state stops being tracked.

2. Replace the match-tests with ground-truth tests

@pytest.mark.parametrize(
    ("move", "exp_device", "exp_dtype"),
    [
        (lambda l: l.cpu(),    "cpu", torch.float32),  # default real dtype
        (lambda l: l.float(),  "cpu", torch.float32),
        (lambda l: l.double(), "cpu", torch.float64),
    ],
)
def test_move_lands_on_expected_state(layer, move, exp_device, exp_dtype):
    moved = move(deepcopy(layer))
    assert_layer_on(moved, device=exp_device, real_dtype=exp_dtype)


@pytest.mark.parametrize(
    ("move", "exp_device", "exp_dtype"),
    [
        (lambda l: l.cpu(),    "cpu", torch.float32),
        (lambda l: l.float(),  "cpu", torch.float32),
        (lambda l: l.double(), "cpu", torch.float64),
    ],
)
def test_move_lands_on_expected_state_noisy(noisy_layer, move, exp_device, exp_dtype):
    moved = move(deepcopy(noisy_layer))
    assert_layer_on(moved, device=exp_device, real_dtype=exp_dtype)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
def test_cuda_lands_on_expected_state(layer):
    moved = deepcopy(layer).cuda()
    assert_layer_on(moved, device="cuda", real_dtype=torch.float32)

If the fix were reverted to a buggy _apply, these tests would keep passing only if the bug still left every tensor on the expected device and dtype, which the PML-379 bug does not. They assert the target state itself, not agreement with a twin path.

3. Behavioral test — the thing the original bug actually broke

The PML-379 bug wasn't "the bookkeeping field is wrong," it was forward() mixing devices → runtime error. The strongest "complete working" test reproduces exactly that: move the layer, then run a forward pass and assert it doesn't raise and the output lands where expected.

def _example_input(layer):
    # adapt to your fixture's expected input shape/dtype (see the existing
    # memristive forward tests in this file for the right shape)
    return torch.rand(1, layer.input_size, dtype=layer.dtype, device=layer.device)


@pytest.mark.parametrize("move", [
    lambda l: l.float(),
    lambda l: l.double(),
    lambda l: l.cpu(),
])
def test_forward_runs_after_move(layer, move):
    moved = move(deepcopy(layer))
    out = moved(_example_input(moved))                # must not raise device/dtype mismatch
    assert out.device.type == (moved.device.type if moved.device is not None else "cpu")
    if out.is_floating_point():
        assert out.dtype == moved.dtype


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA unavailable")
def test_forward_runs_after_cuda_roundtrip(layer):
    moved = deepcopy(layer).cuda().cpu()              # round-trip exercises both directions
    out = moved(_example_input(moved))
    assert out.device.type == "cpu"

If _apply ever leaves the memristive state on the wrong device, this test fails with the real-world symptom, regardless of what the bookkeeping says.

4. Keep the dtype/device-edge tests you already have

test_half_raises_value_error and test_dtype_only_move_leaves_device_unchanged are still valuable and not tautological — they pin contract details (float16 rejected, device=None preserved on pure dtype moves). Keep them.

5. (Optional) one consistency guard against re-introducing a divergent to()

After deletion, to()/.cpu()/.float() are the same code, so a match-test is near-tautological — but one cheap guard is still worth keeping to fail loudly if someone later re-adds a custom to() override that drifts:

def test_to_delegates_to_apply(layer):
    """to() must remain a thin alias of _apply (no divergent override)."""
    via_to    = deepcopy(layer).to(dtype=torch.float64)
    via_apply = deepcopy(layer).double()
    assert_layer_on(via_to,    device="cpu", real_dtype=torch.float64)
    assert_layer_on(via_apply, device="cpu", real_dtype=torch.float64)

Note it asserts both against ground truth, not against each other — so it documents intent without being a pure self-comparison.


⚠️ One real thing the complex dtype exposes — worth a dedicated test

.float()/.double() are defined by torch as floating-point-only (lambda t: t.double() if t.is_floating_point() else t), so they do not touch complex buffers (e.g. unitaries) — those stay complex64 even though _apply sets self.complex_dtype = complex128. .to(torch.float64) uses a different rule (t.is_floating_point() or t.is_complex()). That's a genuine divergence the old path-equality test could never see, and it's exactly why ground-truth testing is the right move. I deliberately left complex out of assert_layer_on above and suggest pinning the intended contract explicitly:

def test_double_complex_dtype_contract(layer):
    moved = deepcopy(layer).double()
    complex_tensors = [t for _, t in _iter_layer_tensors(moved) if t.is_complex()]
    # Decide the contract and assert it. Either:
    #   (a) complex stays at its paired width -> assert all complex128, OR
    #   (b) torch's float-only .double() leaves complex64 untouched.
    # If (a) is what you want but this fails, the fix is in _apply (also convert
    # complex tensors to self.complex_dtype), NOT in the test.
    for t in complex_tensors:
        assert t.dtype == moved.complex_dtype

If that fails, you've found a real _apply gap — not a test bug.


Summary of changes to tests/algorithms/test_layer.py:

  1. Replace the test_apply_methods_match_to[...] / ..._noisy / test_cuda_matches_to* match-tests with the ground-truth versions in §2 (and drop the now-redundant assert_layers_equal/_module_or_sequence_equal comparison helpers, or keep just the one guard in §5).
  2. Add the behavioral forward-after-move tests in §3 — these reproduce the actual PML-379 failure mode.
  3. Keep test_half_raises_value_error and test_dtype_only_move_leaves_device_unchanged.
  4. Add the complex-dtype contract test from the ⚠️ section above to pin the .float()/.double() vs .to() behavior.

I've described the full implementation above rather than committing it — happy to push these into the branch if you'd like; note that prior pushes from this job hit a 403 (github-actions[bot] lacks write to merlinquantum/merlin), and I can't run pytest in this sandbox, so please run pytest -q tests/algorithms/test_layer.py after applying.

@ben9871

ben9871 commented Jun 26, 2026

Contributor

@claude following the recent fixes, Codex raised this concern: "The issue is that QuantumLayer.to() already triggers _apply() via PyTorch, but then repeats its own dtype/device bookkeeping afterward, creating a second path that can drift from .cpu(), .cuda(), .float(), and .double()." Please substantiate it and see if it merits further investigation within this ticket.

@claude

claude Bot commented Jun 26, 2026


Claude finished @ben9871's task in 2m 50s


Substantiating Codex's to() / _apply() drift concern

  • Read current to() and _apply() on the branch
  • Trace the call path and confirm/refute drift
  • Build a concrete drift reproduction
  • Assess whether it merits action in this ticket

Verdict: Codex is correct. The concern is real and in-scope for PML-379.

Both to() (layer.py:1502) and _apply() (layer.py:1549) currently exist on the branch — the earlier suggestion to collapse them was never landed. Tracing the call path confirms the bookkeeping runs twice via two different rules, and they observably diverge.

Why there are two paths

torch.nn.Module.to() is implemented as return self._apply(convert). So when you call layer.to(...):

  1. to() line 1517: super().to(*args, **kwargs) calls nn.Module.to, which calls self._apply(convert). Because _apply is overridden, this dispatches into QuantumLayer._apply, running the entire move: params/buffers, memristive tensors, the probe-based self.device/self.dtype bookkeeping (lines 1593–1604), computation_process.to, and the auxiliary modules.
  2. to() lines 1518–1545 — then re-parses args/kwargs and overwrites self.device / self.dtype / self.complex_dtype a second time, using a different rule.

.cpu(), .cuda(), .float(), .double(), .half() skip step 1's wrapper and go straight through _apply — so they only ever run the probe-based rule. .to() runs both. That asymmetry is the drift.
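The routing described above can be checked on a toy module. This is a minimal sketch, assuming a hypothetical `CountingModule` (not part of Merlin) whose only job is to count how often `_apply` is entered; it relies on `nn.Module.to`, `.cpu()`, `.float()` and `.double()` all delegating to `_apply`.

```python
import torch
from torch import nn


class CountingModule(nn.Module):
    """Hypothetical module that counts entries into _apply."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.apply_calls = 0

    def _apply(self, fn, *args, **kwargs):
        # Every torch move path that reaches this module lands here.
        self.apply_calls += 1
        return super()._apply(fn, *args, **kwargs)


module = CountingModule()
module.to(torch.float64)  # goes through nn.Module.to, which calls _apply
module.cpu()              # calls _apply directly
module.float()            # calls _apply directly
module.double()           # calls _apply directly
print(module.apply_calls)
```

Each of the four calls increments the counter once, which is why any bookkeeping placed in a `to()` override runs in addition to, not instead of, the bookkeeping in `_apply`.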

The rules disagree on self.device

| path | rule for self.device |
| --- | --- |
| _apply (probe) | set self.device = moved.device only if moved.device != old_device; a same-device or dtype-only move leaves it untouched (stays None) |
| to() body (line 1544) | if device is not None: self.device = torch.device(device); sets a concrete device whenever a device= arg was passed, even if nothing moved |

Concrete divergence (traced)

Starting from a fresh layer (self.device is None):

.cpu()        ->  self.device  is None        # _apply: cpu == old cpu, left as-is
.to("cpu")    ->  self.device == device('cpu')  # _apply leaves None, then to() body forces concrete

So layer.cpu().device is None but layer.to("cpu").device == torch.device("cpu"): the same logical no-op move lands the bookkeeping in two different states depending on which alias you call. This is not hypothetical: it's the exact AssertionError: assert device(type='cpu') == None you hit earlier in test_apply_methods_match_to. (I couldn't execute the repro here because the CI sandbox blocks process execution, but the trace is unambiguous and matches the failure you already observed.)
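The divergence can be reproduced outside Merlin with a small stand-in. This is a sketch under stated assumptions: `DriftingLayer` is hypothetical, its `_apply` mimics the probe-based rule and its `to()` mimics the "set whenever a device was passed" rule described above; the real QuantumLayer code is more involved.

```python
import torch
from torch import nn


class DriftingLayer(nn.Module):
    """Hypothetical stand-in for the two bookkeeping rules, not Merlin code."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.device = None  # None means "torch default placement"

    def _apply(self, fn, *args, **kwargs):
        # Probe rule: only record a device if the move actually changed it.
        old_device = self.device if self.device is not None else torch.device("cpu")
        moved = fn(torch.zeros((), device=old_device))
        if moved.device != old_device:
            self.device = moved.device
        return super()._apply(fn, *args, **kwargs)

    def to(self, *args, **kwargs):
        result = super().to(*args, **kwargs)  # already runs _apply above
        # Second rule: record a device whenever one was passed at all.
        device = kwargs.get("device")
        if device is None and args and isinstance(args[0], (str, torch.device)):
            device = args[0]
        if device is not None:
            self.device = torch.device(device)
        return result


print(DriftingLayer().cpu().device)      # probe rule only
print(DriftingLayer().to("cpu").device)  # probe rule, then the to() rule
```

Deleting the `to()` override from this stand-in makes both calls report the same value, which is the structural argument for a single move path.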

Note the current PR tests don't catch it, because assert_layer_on guards the device check with if layer.device is not None: and the tested .to(...) calls pass no device= arg. The drift is masked, not fixed.

Secondary observations from the same trace

  • Pure redundancy / wasted work. Even when the two paths agree (e.g. dtype-only moves), to()'s entire bookkeeping block is dead work — _apply already did it during super().to(). This is the "second path that can drift" Codex names; today it costs a redundant setup_device_and_dtype call, tomorrow it's where the two will silently diverge as one is edited.
  • .half() leaves the module half-converted. In _apply, super()._apply(fn) (line 1578) runs before the probe at line 1596 that raises ValueError for float16. So .half() converts every real param to float16 and then raises — the module is left in a broken partial state. Worth fixing alongside (move the probe/validation before super()._apply).
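The ordering point for .half() can be illustrated with a validate-before-mutate sketch. `ValidatingLayer` is a hypothetical example, not the Merlin implementation: it applies `fn` to a throwaway probe tensor first and only calls `super()._apply(fn)` once the resulting dtype has been accepted.

```python
import torch
from torch import nn


class ValidatingLayer(nn.Module):
    """Hypothetical layer that rejects float16 before touching any parameter."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))

    def _apply(self, fn, *args, **kwargs):
        # Run the conversion on a scalar probe to learn the target dtype.
        probe = fn(torch.zeros((), dtype=self.weight.dtype, device=self.weight.device))
        if probe.dtype == torch.float16:
            raise ValueError("float16 is not supported")
        # Only now convert the real parameters and buffers.
        return super()._apply(fn, *args, **kwargs)


layer = ValidatingLayer()
try:
    layer.half()
except ValueError as err:
    print("rejected:", err)
print(layer.weight.dtype)  # parameters were not converted
```

With the check placed after `super()._apply(fn)` instead, the same call would raise only after the parameters had already become float16.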

Recommendation — yes, address it in this ticket

It's squarely in PML-379's scope (device/dtype consistency across exactly these methods). The clean fix is the one already discussed but never landed: delete the to() override entirely and let _apply be the single source of truth. Since nn.Module.to() is return self._apply(convert), .to(...) then runs the identical code as .cpu()/.float()/.double(), and the drift becomes structurally impossible.

After deleting to(), decide the intended contract for self.device on a same-device/dtype-only move (probe-based logic leaves it None, meaning "torch default placement"). That's a legitimate state everywhere in layer.py, but it does mean .to("cpu").device becomes None rather than concrete — make that the deliberate, documented behavior rather than an accident of which alias was called.

I've kept this to analysis as you asked ("substantiate... and see if it merits investigation"). Happy to implement the to() deletion + .half() ordering fix and adjust the tests if you want me to push it to this branch.

@LF-Vigneux

Contributor Author

@claude review my last fix with regard to your last comment

@claude

claude Bot commented Jun 26, 2026


Claude finished @LF-Vigneux's task in 2m 1s


Review of the latest fix (to() deleted, _apply single source of truth)

  • Read current _apply on the branch (layer.py:1503)
  • Confirm Codex's drift concern is resolved
  • Verify .half() ordering
  • Trace the probe logic for regressions
  • Review the new tests

✅ What's correct now

  • The to() override is gone. _apply (layer.py:1503) is now the single source of truth, and since nn.Module.to() is literally return self._apply(convert), .to(...), .cpu(), .cuda(), .float(), .double() all run the exact same code. This structurally eliminates the drift Codex flagged (findings #3/#4/#5): the .to("cpu").device == device('cpu') vs .cpu().device is None divergence can no longer happen.
  • .half() ordering is fixed. The probe + setup_device_and_dtype validation (lines 1532–1544) now runs before super()._apply(fn) (line 1546), so .half() / torch.float16 raises ValueError before any real tensor is partially converted. Good — this was the partial-conversion concern from the last comment.
  • Memristive tensors move via fn(tensor) (lines 1548–1558), so they inherit PyTorch's float-only semantics rather than a forced dtype= — consistent with params/buffers.
  • Tests match the ground-truth + behavioral approach: assert_layer_on walks every tensor (including memristive lists) with a seen_memristive guard, test_forward_runs_after_move reproduces the original PML-379 failure mode, and test_half_raises_value_error / test_dtype_only_move_leaves_device_unchanged pin the contract. Solid.
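The "move list-held tensors via fn(tensor)" idea from the points above can be sketched as follows. `ListStateLayer` and its `state` attribute are hypothetical names standing in for the memristive lists; the sketch only assumes that tensors kept in a plain Python list are invisible to `nn.Module` and must be converted by hand inside `_apply`.

```python
import torch
from torch import nn


class ListStateLayer(nn.Module):
    """Hypothetical layer holding tensors in a plain list (not buffers)."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        # nn.Module does not track these, so no move path touches them by default.
        self.state = [torch.zeros(2), torch.arange(2)]

    def _apply(self, fn, *args, **kwargs):
        result = super()._apply(fn, *args, **kwargs)
        # Reuse the exact same conversion torch applied to params/buffers.
        self.state = [fn(t) for t in self.state]
        return result


layer = ListStateLayer().double()
print([t.dtype for t in layer.state])
```

Because `fn` is the conversion torch itself built, the list inherits the float-only behaviour: the floating-point entry follows `.double()` while the integer entry keeps its dtype.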

🔴 One real regression — the probe uses the wrong dtype

merlin/algorithms/layer.py:1534-1535

old_dtype = self.dtype if self.dtype is not None else torch.get_default_dtype()
probe = torch.zeros((), device=old_device, dtype=torch.get_default_dtype())  # ← ignores old_dtype

old_dtype is computed but never used in the probe: the probe is built with torch.get_default_dtype() (float32 unless the global default has been changed). The dtype-change check on line 1541 then compares moved.dtype against old_dtype, so when the layer's current dtype isn't the default, a device-only move silently resets the dtype:

layer.double()        # self.dtype = float64, params float64
layer.cpu()           # device-only move
  ├─ old_dtype = float64
  ├─ probe = float32  (BUG: should be float64)
  ├─ fn = .cpu()      → moved.dtype = float32   (dtype untouched by a device move)
  ├─ float32 != float64  → True
  └─ setup_device_and_dtype(None, float32) → self.dtype = float32   ← WRONG
# super()._apply only moved device → params stay float64
# result: self.dtype=float32 + computation_process forced to float32,
#         but params/memristive tensors are float64 → mixed-dtype forward error

This reintroduces exactly the device/dtype drift this PR set out to kill — just on the .double().cpu() (or .cuda() on an already-double layer) path. The current tests don't catch it because every test starts from a fresh float32 layer, so old_dtype happens to equal the default and the bug is masked.

Fix: build the probe with old_dtype:

probe = torch.zeros((), device=old_device, dtype=old_dtype)

Add a regression test (currently missing): deepcopy(layer).double().cpu() then assert_layer_on(..., real_dtype=torch.float64) — it fails today, passes after the one-line fix.
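The effect of the probe dtype can be reproduced on a toy layer. This is a sketch, not the Merlin code: `ProbedLayer` is hypothetical, the `use_old_dtype` flag exists only to switch between the buggy and fixed probe, and the example assumes torch's global default dtype is still float32.

```python
import torch
from torch import nn


class ProbedLayer(nn.Module):
    """Hypothetical layer tracking self.dtype through a probe tensor."""

    def __init__(self, use_old_dtype):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.dtype = None  # None means "torch default dtype"
        self.use_old_dtype = use_old_dtype

    def _apply(self, fn, *args, **kwargs):
        old_dtype = self.dtype if self.dtype is not None else torch.get_default_dtype()
        # Buggy variant probes with the global default; fixed variant with old_dtype.
        probe_dtype = old_dtype if self.use_old_dtype else torch.get_default_dtype()
        moved = fn(torch.zeros((), dtype=probe_dtype))
        if moved.dtype != old_dtype:
            self.dtype = moved.dtype
        return super()._apply(fn, *args, **kwargs)


buggy = ProbedLayer(use_old_dtype=False).double().cpu()
fixed = ProbedLayer(use_old_dtype=True).double().cpu()
print(buggy.dtype, buggy.weight.dtype)  # bookkeeping and parameters disagree
print(fixed.dtype, fixed.weight.dtype)  # bookkeeping and parameters agree
```

A regression test along these lines would start from an already-double layer, perform a device-only move, and assert that the recorded dtype still matches the parameters.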


🟢 Minor

  1. Docstring wrong default: layer.py:1522-1523 says recurse "Default is False", but the signature is recurse=True. Flip the text.
  2. Dead helper in tests: _example_input(layer) (test_layer.py) is defined but unused now that test_forward_runs_after_move calls moved() with no args. Remove it to avoid confusion.
  3. Auxiliary modules still use a second rule — lines 1562–1588 move _photon_loss_transform / _detector_transform / _probability_readout via .to(**target_kwargs) (forcing dtype=self.dtype) rather than fn. It's contained (these are real-valued submodules and .half() already raised earlier), but it's the one remaining spot where a non-fn conversion rule survives. Also note the single-module (non-Sequence) variants are registered submodules already moved by super()._apply(fn), so their explicit .to(...) is redundant — only the Sequence/list variants strictly need manual handling. Low priority; fine to leave for this ticket.
  4. Comment typo — line 1502 "the to method refers to apply_ directly" reads a bit oddly now that to() is deleted; consider ".to() routes through _apply (nn.Module.to calls self._apply), so this is the single move path."

Verdict

The structural fix is right and resolves Codex's concern. The probe-dtype finding is a genuine correctness regression on the double → device-move path and should be fixed (one line), with a regression test added, before merge. The rest are minor. I couldn't run pytest in this environment, so please run pytest -q tests/algorithms/test_layer.py after applying.
