7 changes: 5 additions & 2 deletions src/noether/core/utils/common/stopwatch.py
@@ -114,10 +114,13 @@ def sync(device: torch.device) -> None:
 def _flush_pending_gpu_laps(self) -> None:
     """Resolve pending GPU event pairs into elapsed seconds.
 
-    Synchronizes each end event before calling ``elapsed_time()``.
+    Synchronizes the device once before resolving all pending event pairs.
+    Per-event synchronization (``end_event.synchronize()``) can stall on MPS,
+    so a single device-level sync is used instead.
     """
+    if self._gpu_pending_laps and self._device is not None:
+        self.sync(self._device)

Collaborator:
I think this should be fine, but I am slightly worried that this causes more synchronization than necessary on CUDA. Could we perhaps do this only for MPS?

     for start_event, end_event in self._gpu_pending_laps:
-        end_event.synchronize()
         self._elapsed_seconds.append(start_event.elapsed_time(end_event) / 1000.0)  # type: ignore[arg-type]
     self._gpu_pending_laps.clear()

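The reviewer's suggestion above would keep the cheaper per-event path on CUDA and reserve the device-wide sync for MPS. Below is a minimal sketch of that variant, assuming the Stopwatch attributes visible in this diff (_gpu_pending_laps, _device, _elapsed_seconds) and the sync() helper from this file; this is a hypothetical illustration of the suggestion, not the merged code:

    def _flush_pending_gpu_laps(self) -> None:
        """Hypothetical variant: device-wide sync on MPS only, per-event sync on CUDA."""
        if not self._gpu_pending_laps:
            return
        if self._device is not None and self._device.type == "mps":
            # MPS: a single device-level sync; per-event synchronize() can stall here.
            self.sync(self._device)
        for start_event, end_event in self._gpu_pending_laps:
            if self._device is not None and self._device.type == "cuda":
                # CUDA: waiting on the end event alone avoids stalling on
                # unrelated work queued elsewhere on the device.
                end_event.synchronize()
            self._elapsed_seconds.append(start_event.elapsed_time(end_event) / 1000.0)
        self._gpu_pending_laps.clear()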
104 changes: 104 additions & 0 deletions tests/unit/noether/core/utils/common/test_stopwatch.py
@@ -4,11 +4,19 @@
 from unittest.mock import patch
 
 import pytest
+import torch
 
 from noether.core.utils.common.stopwatch import Stopwatch
 
 MODULE_PATH = "noether.core.utils.common.stopwatch"
 
+_has_cuda = torch.cuda.is_available()
+_has_mps = torch.backends.mps.is_available()
+
+requires_cuda = pytest.mark.skipif(not _has_cuda, reason="CUDA not available")
+requires_mps = pytest.mark.skipif(not _has_mps, reason="MPS not available")
+requires_gpu = pytest.mark.skipif(not (_has_cuda or _has_mps), reason="No GPU available")
+
 
 @patch(f"{MODULE_PATH}.time.perf_counter")
 def test_stopwatch_basic_flow(mock_time):
@@ -133,3 +141,99 @@ def test_real_time_accuracy():
     # Check total:
     assert sw.elapsed_seconds == pytest.approx(0.3, abs=0.02)
     assert sw.lap_count == 2
+
+
+def _gpu_device() -> torch.device:
+    """Return the first available GPU device."""
+    if _has_cuda:
+        return torch.device("cuda")
+    if _has_mps:
+        return torch.device("mps")
+    pytest.skip("No GPU available")
+
+
+def _warmup_gpu(device: torch.device) -> None:
+    """Run a dummy op to ensure the GPU device is initialized (MPS compiles shaders lazily)."""
+    x = torch.randn(32, 32, device=device)
+    _ = x @ x
+    Stopwatch.sync(device)
+
+
+@requires_gpu
+def test_gpu_stopwatch_basic():
+    """Start -> stop on GPU returns a non-negative elapsed time."""
+    device = _gpu_device()
+    _warmup_gpu(device)
+    x = torch.randn(128, 128, device=device)
+    with Stopwatch(device=device) as sw:
+        _ = x @ x
+    assert sw.elapsed_seconds >= 0
+    assert sw.lap_count == 1
+
+
+@requires_gpu
+def test_gpu_stopwatch_multiple_steps():
+    """GPU stopwatch works across many consecutive start/stop cycles without hanging."""
+    device = _gpu_device()
+    _warmup_gpu(device)
+    x = torch.randn(128, 128, device=device)
+    for i in range(20):
+        with Stopwatch(device=device) as sw:
+            _ = x @ x
+        assert sw.elapsed_seconds >= 0, f"step {i} returned negative elapsed time"
+
+
+@requires_cuda
+def test_gpu_stopwatch_laps():
+    """GPU stopwatch correctly records multiple laps.
+
+    MPS does not support this because lap() reuses the end event as the start
+    of the next lap, and torch.mps.Event.elapsed_time() rejects shared events.
+    This is not an issue in practice since lap() is unused in the training loop.
+    """
+    device = torch.device("cuda")
+    _warmup_gpu(device)
+    x = torch.randn(128, 128, device=device)
+    sw = Stopwatch(device=device)
+    sw.start()
+    _ = x @ x
+    sw.lap()
+    _ = x @ x
+    sw.stop()
+    # GPU laps are resolved lazily when elapsed_seconds is accessed
+    assert sw.elapsed_seconds >= 0
+    assert sw.lap_count == 2
+
+
+@requires_gpu
+def test_gpu_stopwatch_falls_back_to_cpu_for_unknown_device():
+    """Stopwatch with a CPU device uses wall-clock timing, not GPU events."""
+    sw = Stopwatch(device=torch.device("cpu"))
+    assert not sw._use_gpu
+    sw.start()
+    time.sleep(0.05)
+    sw.stop()
+    assert sw.elapsed_seconds > 0
+
+
+@requires_mps
+def test_mps_event_sync_does_not_hang():
+    """Regression test: per-event synchronization was replaced with device-level sync
+    because torch.mps.Event.synchronize() can hang. Verify the flush path works."""
+    device = torch.device("mps")
+    x = torch.randn(256, 256, device=device)
+    for _ in range(30):
+        with Stopwatch(device=device) as sw:
+            _ = x @ x
+        # accessing elapsed_seconds triggers _flush_pending_gpu_laps
+        assert sw.elapsed_seconds >= 0
+
+
+@requires_cuda
+def test_cuda_stopwatch_basic():
+    """Smoke test: CUDA stopwatch records timing without error."""
+    device = torch.device("cuda")
+    x = torch.randn(256, 256, device=device)
+    with Stopwatch(device=device) as sw:
+        _ = x @ x
+    assert sw.elapsed_seconds >= 0
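
For context on why test_gpu_stopwatch_laps is CUDA-only: the shared-event pattern its docstring describes looks roughly like the standalone sketch below. The torch.cuda event API calls are real; the event-chaining inside Stopwatch.lap() is assumed from the docstring rather than copied from the implementation.

    import torch

    # Lap timing chains events: the event that ends lap 1 also starts lap 2.
    # torch.cuda.Event accepts a shared event in elapsed_time(); per the test
    # docstring above, torch.mps.Event rejects it.
    device = torch.device("cuda")
    x = torch.randn(128, 128, device=device)

    start = torch.cuda.Event(enable_timing=True)
    mid = torch.cuda.Event(enable_timing=True)  # shared: end of lap 1, start of lap 2
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    _ = x @ x  # lap 1 work
    mid.record()
    _ = x @ x  # lap 2 work
    end.record()
    torch.cuda.synchronize()

    lap1_s = start.elapsed_time(mid) / 1000.0
    lap2_s = mid.elapsed_time(end) / 1000.0  # reuses mid: fine on CUDA, rejected on MPS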