PML-379: Fix .cuda(), .cpu(), .float() for memristor#278
Conversation
|
@claude review |
|
Claude finished @LF-Vigneux's task in 2m 50s —— View job PR Review: PML-379 Fix
|
|
@claude I have tried to correct your 🟡 Design / maintainability concerns, however I have a problem with the fact that device=None does not pass as None device with the apply function now. How would you fix it? |
|
Claude finished @LF-Vigneux's task in 5m 32s —— View job PML-379 — Fixing
|
| call | old_device |
moved.device |
result |
|---|---|---|---|
.float() (dtype only) |
cpu/None | cpu | device left as-is (None preserved) ✅ |
.cuda() from cpu |
cpu | cuda:0 | self.device = cuda:0 ✅ |
.cpu() from cuda |
cuda:0 | cpu | self.device = cpu ✅ |
Two issues from my earlier review fall out for free
.half()no longer crashes — thedtype in (float32, float64)guard skipssetup_device_and_dtype(which raisesValueErroronfloat16/complex).- Auxiliary list-modules now reuse
fndirectly instead of a reconstructedtarget_kwargs, so they inherit the samedevice=None/float-only semantics. Single-module variants are registered submodules and were already moved bysuper()._apply, so only theSequencevariants need handling:
for aux in (self._photon_loss_transform, self._detector_transform, self._probability_readout):
if isinstance(aux, Sequence):
for module in aux:
module._apply(fn, recurse=recurse)Bonus: removed the now-redundant to() override
nn.Module.to() is return self._apply(convert). Now that _apply is the single source of truth, the to() override only duplicated logic (and its own arg-parsing would still ValueError on .to(torch.float16)). Deleting it removes the double-move (review finding #3) and the divergence (#4/#5), and routes .to() through the same robust path.
⚠️ Couldn't push
I committed the change locally (eef6e17), but the push was rejected — Permission to merlinquantum/merlin.git denied to github-actions[bot] (403). So you'll need to apply it. Full diff below.
Diff for merlin/algorithms/layer.py
@@ class QuantumLayer(MerlinModule):
- def to(self, *args: Any, **kwargs: Any) -> QuantumLayer:
- """Move the layer and auxiliary transforms to a new device or dtype.
- ...
- """
- super().to(*args, **kwargs)
- ... # (entire to() override deleted)
- return self
-
def _apply(self, fn, recurse=True):
...
super()._apply(fn, recurse=recurse)
# ... memristive tensors moved with fn(tensor) ...
- # infer canonical device/dtype AFTER move
- ref_tensor = None
- for p in self.parameters():
- ref_tensor = p
- break
- if ref_tensor is None:
- for b in self.buffers():
- ref_tensor = b
- break
- if ref_tensor is not None:
- self.device = ref_tensor.device
- if ref_tensor.dtype in (torch.float32, torch.float64):
- _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
- None, ref_tensor.dtype
- )
+ old_device = self.device if self.device is not None else torch.device("cpu")
+ probe = torch.zeros((), device=old_device, dtype=torch.get_default_dtype())
+ moved = fn(probe)
+ if moved.device != old_device:
+ self.device = moved.device
+ if moved.dtype in (torch.float32, torch.float64):
+ _, self.dtype, self.complex_dtype = MerlinModule.setup_device_and_dtype(
+ None, moved.dtype
+ )
self.computation_process.to(dtype=self.dtype, device=self.device)
- target_kwargs = {"dtype": self.dtype}
- if self.device is not None:
- target_kwargs["device"] = self.device
- if self._photon_loss_transform is not None:
- ... # .to(**target_kwargs) for both single + list variants
- if self._detector_transform is not None:
- ...
- if self._probability_readout is not None:
- self._probability_readout = self._probability_readout.to(device=self.device)
+ for aux in (
+ self._photon_loss_transform,
+ self._detector_transform,
+ self._probability_readout,
+ ):
+ if isinstance(aux, Sequence):
+ for module in aux:
+ module._apply(fn, recurse=recurse)
return selfNote: I couldn't run
pytestin this environment (sandbox blocks process execution). Please runpytest -q tests/algorithms/test_layer.pyafter applying, and consider adding a.half()test plus an assertion thatlayer.to(torch.float64)leaveslayer.deviceunchanged — both now behave correctly with this fix.
·PML-379-Fix-.device-for-memristor
|
@claude half should fail no? Is it what it was happening with to. If so, how do I keep it. Also, the added tests are failing right now because AssertionError: assert device(type='cpu') == None, how would you change that |
|
Claude finished @LF-Vigneux's task in 6m 3s —— View job PML-379 —
|
|
@claude Since the to and apply are now indentical, excatly how would you change the tests and implement them to test the complete working |
|
Claude finished @LF-Vigneux's task in 3m 2s —— View job How to test once
|
@claude following the recent fixes, codex raised this concern "The issue is that QuantumLayer.to() already triggers _apply() via PyTorch, but then repeats its own dtype/device bookkeeping afterward, creating a second path that can drift from .cpu(), .cuda(), .float(), and .double().". please substantiate it and see if it merits further investigation within this ticket |
|
Claude finished @ben9871's task in 2m 50s —— View job Substantiating Codex's
|
| path | rule for self.device |
|---|---|
_apply (probe) |
set self.device = moved.device only if moved.device != old_device — a same-device or dtype-only move leaves it untouched (stays None) |
to() body (line 1544) |
if device is not None: self.device = torch.device(device) — sets a concrete device whenever a device= arg was passed, even if nothing moved |
Concrete divergence (traced)
Starting from a fresh layer (self.device is None):
.cpu() -> self.device is None # _apply: cpu == old cpu, left as-is
.to("cpu") -> self.device == device('cpu') # _apply leaves None, then to() body forces concrete
So layer.cpu().device is None but layer.to("cpu").device == torch.device("cpu") — the same logical no-op move lands the bookkeeping in two different states depending on which alias you call. This is not hypothetical: it's the exact AssertionError: assert device(type='cpu') == None you hit earlier in test_apply_methods_match_to. (I couldn't execute the repro here — the CI sandbox blocks process execution — but the trace is unambiguous and matches the failure you already observed.)
Note the current PR tests don't catch it, because assert_layer_on guards the device check with if layer.device is not None: and the tested .to(...) calls pass no device= arg. The drift is masked, not fixed.
Secondary observations from the same trace
- Pure redundancy / wasted work. Even when the two paths agree (e.g. dtype-only moves),
to()'s entire bookkeeping block is dead work —_applyalready did it duringsuper().to(). This is the "second path that can drift" Codex names; today it costs a redundantsetup_device_and_dtypecall, tomorrow it's where the two will silently diverge as one is edited. .half()leaves the module half-converted. In_apply,super()._apply(fn)(line 1578) runs before the probe at line 1596 that raisesValueErrorforfloat16. So.half()converts every real param tofloat16and then raises — the module is left in a broken partial state. Worth fixing alongside (move the probe/validation beforesuper()._apply).
Recommendation — yes, address it in this ticket
It's squarely in PML-379's scope (device/dtype consistency across exactly these methods). The clean fix is the one already discussed but never landed: delete the to() override entirely and let _apply be the single source of truth. Since nn.Module.to() is return self._apply(convert), .to(...) then runs the identical code as .cpu()/.float()/.double(), and the drift becomes structurally impossible.
After deleting to(), decide the intended contract for self.device on a same-device/dtype-only move (probe-based logic leaves it None, meaning "torch default placement"). That's a legitimate state everywhere in layer.py, but it does mean .to("cpu").device becomes None rather than concrete — make that the deliberate, documented behavior rather than an accident of which alias was called.
I've kept this to analysis as you asked ("substantiate... and see if it merits investigation"). Happy to implement the to() deletion + .half() ordering fix and adjust the tests if you want me to push it to this branch.
· PML-379-Fix-.device-for-memristor
|
@claude review my last fix in regards to your last comment |
|
Claude finished @LF-Vigneux's task in 2m 1s —— View job Review of the latest fix (
|
Summary
Issue raised by Claude on review of release 0.4.0
Memristive state not moved by .cuda()/.cpu()/.float() (device/dtype split)
QuantumLayer overrides to() to move memristive_state/memristive_history, but nn.Module.cuda()/.cpu()/.float()/.double()/.half() route through _apply, not to(). Since the memristive tensors are plain Python lists (not registered buffers), those move paths leave them on the old device/dtype, and the next forward() mixes devices → runtime error or silent wrong-device compute.
Fix: override _apply (or register the state/history as buffers) so all torch move paths convert the memristive tensors.
Related Issue
PML-379
Type of change
Proposed changes
How to test / How to run
Documentation