From 3cc4a453a8d18d0c4192be978520c4d417026822 Mon Sep 17 00:00:00 2001 From: Pavel Kuksa Date: Wed, 22 Apr 2026 13:07:48 +0200 Subject: [PATCH] fix: use device-level sync instead of per-event in Stopwatch to prevent MPS hang --- src/noether/core/utils/common/stopwatch.py | 7 +- .../core/utils/common/test_stopwatch.py | 104 ++++++++++++++++++ 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/src/noether/core/utils/common/stopwatch.py b/src/noether/core/utils/common/stopwatch.py index 681043c1..60a69d5d 100644 --- a/src/noether/core/utils/common/stopwatch.py +++ b/src/noether/core/utils/common/stopwatch.py @@ -114,10 +114,13 @@ def sync(device: torch.device) -> None: def _flush_pending_gpu_laps(self) -> None: """Resolve pending GPU event pairs into elapsed seconds. - Synchronizes each end event before calling ``elapsed_time()``. + Synchronizes the device once before resolving all pending event pairs. + Per-event synchronization (``end_event.synchronize()``) can stall on MPS, + so a single device-level sync is used instead. """ + if self._gpu_pending_laps and self._device is not None: + self.sync(self._device) for start_event, end_event in self._gpu_pending_laps: - end_event.synchronize() self._elapsed_seconds.append(start_event.elapsed_time(end_event) / 1000.0) # type: ignore[arg-type] self._gpu_pending_laps.clear() diff --git a/tests/unit/noether/core/utils/common/test_stopwatch.py b/tests/unit/noether/core/utils/common/test_stopwatch.py index bc94cbaf..62ac40f5 100644 --- a/tests/unit/noether/core/utils/common/test_stopwatch.py +++ b/tests/unit/noether/core/utils/common/test_stopwatch.py @@ -4,11 +4,19 @@ from unittest.mock import patch import pytest +import torch from noether.core.utils.common.stopwatch import Stopwatch MODULE_PATH = "noether.core.utils.common.stopwatch" +_has_cuda = torch.cuda.is_available() +_has_mps = torch.backends.mps.is_available() + +requires_cuda = pytest.mark.skipif(not _has_cuda, reason="CUDA not available") +requires_mps = pytest.mark.skipif(not _has_mps, reason="MPS not available") +requires_gpu = pytest.mark.skipif(not (_has_cuda or _has_mps), reason="No GPU available") + @patch(f"{MODULE_PATH}.time.perf_counter") def test_stopwatch_basic_flow(mock_time): @@ -133,3 +141,99 @@ def test_real_time_accuracy(): # Check total: assert sw.elapsed_seconds == pytest.approx(0.3, abs=0.02) assert sw.lap_count == 2 + + +def _gpu_device() -> torch.device: + """Return the first available GPU device.""" + if _has_cuda: + return torch.device("cuda") + if _has_mps: + return torch.device("mps") + pytest.skip("No GPU available") + + +def _warmup_gpu(device: torch.device) -> None: + """Run a dummy op to ensure the GPU device is initialized (MPS compiles shaders lazily).""" + x = torch.randn(32, 32, device=device) + _ = x @ x + Stopwatch.sync(device) + + +@requires_gpu +def test_gpu_stopwatch_basic(): + """Start -> stop on GPU returns a positive elapsed time.""" + device = _gpu_device() + _warmup_gpu(device) + x = torch.randn(128, 128, device=device) + with Stopwatch(device=device) as sw: + _ = x @ x + assert sw.elapsed_seconds >= 0 + assert sw.lap_count == 1 + + +@requires_gpu +def test_gpu_stopwatch_multiple_steps(): + """GPU stopwatch works across many consecutive start/stop cycles without hanging.""" + device = _gpu_device() + _warmup_gpu(device) + x = torch.randn(128, 128, device=device) + for i in range(20): + with Stopwatch(device=device) as sw: + _ = x @ x + assert sw.elapsed_seconds >= 0, f"step {i} returned negative elapsed time" + + +@requires_cuda +def test_gpu_stopwatch_laps(): + """GPU stopwatch correctly records multiple laps. + + MPS does not support this because lap() reuses the end event as the start + of the next lap, and torch.mps.Event.elapsed_time() rejects shared events. + This is not an issue in practice since lap() is unused in the training loop. + """ + device = torch.device("cuda") + _warmup_gpu(device) + x = torch.randn(128, 128, device=device) + sw = Stopwatch(device=device) + sw.start() + _ = x @ x + sw.lap() + _ = x @ x + sw.stop() + # GPU laps are resolved lazily when elapsed_seconds is accessed + assert sw.elapsed_seconds >= 0 + assert sw.lap_count == 2 + + +@requires_gpu +def test_gpu_stopwatch_falls_back_to_cpu_for_unknown_device(): + """Stopwatch with a CPU device uses wall-clock timing, not GPU events.""" + sw = Stopwatch(device=torch.device("cpu")) + assert not sw._use_gpu + sw.start() + time.sleep(0.05) + sw.stop() + assert sw.elapsed_seconds > 0 + + +@requires_mps +def test_mps_event_sync_does_not_hang(): + """Regression test: per-event synchronization was replaced with device-level sync + because torch.mps.Event.synchronize() can hang. Verify the flush path works.""" + device = torch.device("mps") + x = torch.randn(256, 256, device=device) + for _ in range(30): + with Stopwatch(device=device) as sw: + _ = x @ x + # accessing elapsed_seconds triggers _flush_pending_gpu_laps + assert sw.elapsed_seconds >= 0 + + +@requires_cuda +def test_cuda_stopwatch_basic(): + """Smoke test: CUDA stopwatch records timing without error.""" + device = torch.device("cuda") + x = torch.randn(256, 256, device=device) + with Stopwatch(device=device) as sw: + _ = x @ x + assert sw.elapsed_seconds >= 0