From fd3d9977c713eb1ab72b5035aff2a7bee26348fa Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Tue, 14 Apr 2026 13:49:46 -0700 Subject: [PATCH] add support for max_num_hosts in start_trace PiperOrigin-RevId: 899756971 --- pathwaysutils/profiling.py | 11 +++++- pathwaysutils/test/profiling_test.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index a3d5671..b4f4378 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -80,10 +80,12 @@ def _is_default_profile_options( def _create_profile_request( log_dir: os.PathLike[str] | str, profiler_options: jax.profiler.ProfileOptions | None = None, + max_num_hosts: int = 1, ) -> Mapping[str, Any]: """Creates a profile request mapping from the given options.""" profile_request: dict[str, Any] = { "traceLocation": str(log_dir), + "maxNumHosts": max_num_hosts, } if profiler_options is None or _is_default_profile_options(profiler_options): @@ -173,6 +175,7 @@ def start_trace( create_perfetto_link: bool = False, create_perfetto_trace: bool = False, profiler_options: jax.profiler.ProfileOptions | None = None, + max_num_hosts: int = 1, ) -> None: """Starts a profiler trace. @@ -201,6 +204,8 @@ def start_trace( This feature is experimental for Pathways on Cloud and may not be fully supported. profiler_options: Profiler options to configure the profiler for collection. + max_num_hosts: An optional integer to limit the number of hosts profiled + (defaults to 1). """ if not str(log_dir).startswith("gs://"): raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}") @@ -218,7 +223,9 @@ def start_trace( ) profiler_options = None - profile_request = _create_profile_request(log_dir, profiler_options) + profile_request = _create_profile_request( + log_dir, profiler_options, max_num_hosts=max_num_hosts + ) _logger.debug("Profile request: %s", profile_request) @@ -366,6 +373,7 @@ def start_trace_patch( create_perfetto_link: bool = False, create_perfetto_trace: bool = False, profiler_options: jax.profiler.ProfileOptions | None = None, + max_num_hosts: int = 1, ) -> None: _logger.debug("jax.profile.start_trace patched with pathways' start_trace") start_trace( @@ -373,6 +381,7 @@ def start_trace_patch( create_perfetto_link=create_perfetto_link, create_perfetto_trace=create_perfetto_trace, profiler_options=profiler_options, + max_num_hosts=max_num_hosts, ) jax.profiler.start_trace = start_trace_patch diff --git a/pathwaysutils/test/profiling_test.py b/pathwaysutils/test/profiling_test.py index 66e6d57..033901e 100644 --- a/pathwaysutils/test/profiling_test.py +++ b/pathwaysutils/test/profiling_test.py @@ -232,6 +232,7 @@ def test_start_trace_success(self): json.dumps({ "profileRequest": { "traceLocation": "gs://test_bucket/test_dir", + "maxNumHosts": 1, } }) ) @@ -243,6 +244,25 @@ def test_start_trace_success(self): ) self.assertIsNotNone(profiling._profile_state.executable) + def test_start_trace_with_max_num_hosts(self): + profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10) + + self.mock_toy_computation.assert_called_once() + self.mock_plugin_executable_cls.assert_called_once_with( + json.dumps({ + "profileRequest": { + "traceLocation": "gs://test_bucket/test_dir", + "maxNumHosts": 10, + } + }) + ) + self.mock_plugin_executable_cls.return_value.call.assert_called_once() + self.mock_original_start_trace.assert_called_once_with( + log_dir="gs://test_bucket/test_dir", + create_perfetto_link=False, + create_perfetto_trace=False, + ) + def test_start_trace_no_toy_computation_second_time(self): profiling.start_trace("gs://test_bucket/test_dir") profiling.stop_trace() @@ -408,6 +428,24 @@ def test_monkey_patched_start_trace(self, profiler_module): create_perfetto_link=False, create_perfetto_trace=False, profiler_options=None, + max_num_hosts=1, + ) + + @parameterized.named_parameters( + dict(testcase_name="jax_profiler", profiler_module=jax.profiler), + dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler), + ) + def test_monkey_patched_start_trace_with_max_num_hosts(self, profiler_module): + mocks = self._setup_monkey_patch() + + profiler_module.start_trace("gs://bucket/dir", max_num_hosts=3) + + mocks["start_trace"].assert_called_once_with( + "gs://bucket/dir", + create_perfetto_link=False, + create_perfetto_trace=False, + profiler_options=None, + max_num_hosts=3, ) @parameterized.named_parameters( @@ -444,6 +482,19 @@ def test_create_profile_request_default_options(self, profiler_options): request, { "traceLocation": "gs://bucket/dir", + "maxNumHosts": 1, + }, + ) + + def test_create_profile_request_with_max_num_hosts(self): + request = profiling._create_profile_request( + "gs://bucket/dir", max_num_hosts=5 + ) + self.assertEqual( + request, + { + "traceLocation": "gs://bucket/dir", + "maxNumHosts": 5, }, ) @@ -471,6 +522,7 @@ def test_create_profile_request_with_options(self): { "traceLocation": "gs://bucket/dir", "maxDurationSecs": 2.0, + "maxNumHosts": 1, "xprofTraceOptions": { "traceDirectory": "gs://bucket/dir", "pwTraceOptions": {