Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mlpstorage_py/benchmarks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,22 @@ def _collect_cluster_information(self) -> 'ClusterInformation':
mpi_bin = getattr(self.args, 'mpi_bin', 'mpirun')
allow_run_as_root = getattr(self.args, 'allow_run_as_root', False)
timeout = getattr(self.args, 'cluster_collection_timeout', 60)
ssh_username = getattr(self.args, 'ssh_username', None)
shared_staging_dir = getattr(self.args, 'shared_staging_dir', None)

# Collect cluster info
# Collect cluster info. ``results_dir`` is required by
# ``collect_cluster_info`` for staging the helper script under
# ``<results_dir>/collector-staging/`` (see issue #363).
collected_data = collect_cluster_info(
hosts=self.args.hosts,
mpi_bin=mpi_bin,
logger=self.logger,
results_dir=self.run_result_output,
allow_run_as_root=allow_run_as_root,
timeout_seconds=timeout,
fallback_to_local=True
fallback_to_local=True,
shared_staging_dir=shared_staging_dir,
ssh_username=ssh_username,
)

# Create ClusterInformation from collected data
Expand Down
128 changes: 126 additions & 2 deletions mlpstorage_py/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,20 +218,30 @@ def _run(self):
benchmark = TestBenchmark.__new__(TestBenchmark)
benchmark.args = base_args
benchmark.logger = mock_logger
# ``run_result_output`` is normally set in ``Benchmark.__init__``
# via ``generate_output_location()``. We patched ``__init__``
# away, so set it explicitly so the call site has a results dir
# to forward to ``collect_cluster_info`` (issue #363).
benchmark.run_result_output = '/tmp/results/run-001'

with patch('mlpstorage_py.benchmarks.base.collect_cluster_info') as mock_collect:
mock_collect.return_value = mock_collected_data

result = benchmark._collect_cluster_information()

# Verify collect_cluster_info was called with correct args
# Verify collect_cluster_info was called with correct args.
# ``results_dir`` is REQUIRED by collect_cluster_info; missing
# it was the root cause of issue #363.
mock_collect.assert_called_once_with(
hosts=['host1', 'host2'],
mpi_bin='mpirun',
logger=mock_logger,
results_dir='/tmp/results/run-001',
allow_run_as_root=False,
timeout_seconds=60,
fallback_to_local=True
fallback_to_local=True,
shared_staging_dir=None,
ssh_username=None,
)

# Verify result is a ClusterInformation instance
Expand Down Expand Up @@ -260,6 +270,120 @@ def _run(self):
assert result is None


# =============================================================================
# Regression tests for issue #363
# =============================================================================
# The original bug was that ``Benchmark._collect_cluster_information`` called
# ``collect_cluster_info`` without the required ``results_dir`` argument. Every
# pre-existing test patched ``collect_cluster_info`` away, so the missing-arg
# ``TypeError`` never surfaced. The tests below validate the call against the
# *real* function signature so future signature drift is caught at unit-test
# time.

class TestCollectClusterInfoSignatureBinding:
"""Issue #363: guard ``_collect_cluster_information`` against signature drift."""

def test_call_binds_to_real_collect_cluster_info_signature(
self, base_args, mock_logger
):
"""The kwargs passed by ``_collect_cluster_information`` must bind to
the real ``collect_cluster_info`` signature without raising
``TypeError`` for missing required arguments.

This is what would have caught issue #363 before merge.
"""
import inspect
from mlpstorage_py.benchmarks.base import Benchmark
from mlpstorage_py.cluster_collector import collect_cluster_info

class TestBenchmark(Benchmark):
BENCHMARK_TYPE = BENCHMARK_TYPES.training
def _run(self):
pass

sig = inspect.signature(collect_cluster_info)
captured_kwargs = {}

def capture(*args, **kwargs):
# Reject positional shadowing — the call site is keyword-only.
assert not args, "call site should use keyword arguments only"
captured_kwargs.update(kwargs)
# Validate against the REAL signature; this raises TypeError if
# any required parameter (e.g., ``results_dir``) is missing.
sig.bind(**kwargs)
return {
'host1': {'hostname': 'host1', 'meminfo': {'MemTotal': 16384000}},
'_metadata': {
'collection_method': 'mpi',
'collection_timestamp': '2024-01-01T00:00:00Z',
},
}

with patch.object(TestBenchmark, '__init__', lambda x, *a, **kw: None):
benchmark = TestBenchmark.__new__(TestBenchmark)
benchmark.args = base_args
benchmark.logger = mock_logger
benchmark.run_result_output = '/tmp/results/run-001'

with patch(
'mlpstorage_py.benchmarks.base.collect_cluster_info',
side_effect=capture,
):
benchmark._collect_cluster_information()

# ``results_dir`` is the parameter that was missing in issue #363.
assert 'results_dir' in captured_kwargs
assert captured_kwargs['results_dir'] == '/tmp/results/run-001'

def test_warning_message_from_issue_363_is_not_emitted(
self, base_args, mock_logger
):
"""The exact warning ``MPI cluster info collection failed:
collect_cluster_info() missing 1 required positional argument:
'results_dir'`` must NOT appear after the fix.
"""
from mlpstorage_py.benchmarks.base import Benchmark

class TestBenchmark(Benchmark):
BENCHMARK_TYPE = BENCHMARK_TYPES.training
def _run(self):
pass

warnings_seen = []

class CapturingLogger(MockLogger):
def warning(self, msg):
warnings_seen.append(msg)

with patch.object(TestBenchmark, '__init__', lambda x, *a, **kw: None):
benchmark = TestBenchmark.__new__(TestBenchmark)
benchmark.args = base_args
benchmark.logger = CapturingLogger()
benchmark.run_result_output = '/tmp/results/run-001'

# Use the REAL ``collect_cluster_info`` but stub out the heavy
# ``MPIClusterCollector`` so we don't need an actual cluster.
with patch(
'mlpstorage_py.cluster_collector.MPIClusterCollector'
) as mock_collector_cls:
mock_instance = MagicMock()
mock_instance.collect.return_value = {
'host1': {'hostname': 'host1', 'meminfo': {'MemTotal': 16384000}},
}
mock_collector_cls.return_value = mock_instance

benchmark._collect_cluster_information()

offending = [
w for w in warnings_seen
if 'missing 1 required positional argument' in w
and 'results_dir' in w
]
assert offending == [], (
f"Issue #363 warning regressed: {offending}"
)


# =============================================================================
# Tests for DLIOBenchmark.accumulate_host_info
# =============================================================================
Expand Down
Loading