Skip to content

Commit 4b28263

Browse files
Pathways-on-Cloud Teamcopybara-github
authored andcommitted
add support for max_num_hosts in start_trace
PiperOrigin-RevId: 899756971
1 parent c2a9fe0 commit 4b28263

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

pathwaysutils/profiling.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,19 @@ def _is_default_profile_options(
8080
def _create_profile_request(
8181
log_dir: os.PathLike[str] | str,
8282
profiler_options: jax.profiler.ProfileOptions | None = None,
83+
max_num_hosts: int = 1,
8384
) -> Mapping[str, Any]:
8485
"""Creates a profile request mapping from the given options."""
8586
profile_request: dict[str, Any] = {
8687
"traceLocation": str(log_dir),
88+
"maxNumHosts": max_num_hosts,
8789
}
8890

89-
if profiler_options is None or _is_default_profile_options(profiler_options):
90-
return profile_request
91+
if profiler_options is None:
92+
profiler_options = jax.profiler.ProfileOptions()
93+
94+
# C++ code requires xprofTraceOptions to be set to respect max_num_hosts.
95+
# So we do not return early here if we want max_num_hosts to take effect.
9196

9297
advanced_config = None
9398
if getattr(profiler_options, "advanced_configuration", None):
@@ -173,6 +178,7 @@ def start_trace(
173178
create_perfetto_link: bool = False,
174179
create_perfetto_trace: bool = False,
175180
profiler_options: jax.profiler.ProfileOptions | None = None,
181+
max_num_hosts: int = 1,
176182
) -> None:
177183
"""Starts a profiler trace.
178184
@@ -201,6 +207,8 @@ def start_trace(
201207
This feature is experimental for Pathways on Cloud and may not be fully
202208
supported.
203209
profiler_options: Profiler options to configure the profiler for collection.
210+
max_num_hosts: An optional integer to limit the number of hosts profiled
211+
(defaults to 1).
204212
"""
205213
if not str(log_dir).startswith("gs://"):
206214
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
@@ -218,7 +226,9 @@ def start_trace(
218226
)
219227
profiler_options = None
220228

221-
profile_request = _create_profile_request(log_dir, profiler_options)
229+
profile_request = _create_profile_request(
230+
log_dir, profiler_options, max_num_hosts=max_num_hosts
231+
)
222232

223233
_logger.debug("Profile request: %s", profile_request)
224234

@@ -366,13 +376,15 @@ def start_trace_patch(
366376
create_perfetto_link: bool = False,
367377
create_perfetto_trace: bool = False,
368378
profiler_options: jax.profiler.ProfileOptions | None = None,
379+
max_num_hosts: int = 1,
369380
) -> None:
370381
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
371382
start_trace(
372383
log_dir,
373384
create_perfetto_link=create_perfetto_link,
374385
create_perfetto_trace=create_perfetto_trace,
375386
profiler_options=profiler_options,
387+
max_num_hosts=max_num_hosts,
376388
)
377389

378390
jax.profiler.start_trace = start_trace_patch

pathwaysutils/test/profiling_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ def test_start_trace_success(self):
232232
json.dumps({
233233
"profileRequest": {
234234
"traceLocation": "gs://test_bucket/test_dir",
235+
"maxNumHosts": 1,
236+
"xprofTraceOptions": {
237+
"traceDirectory": "gs://test_bucket/test_dir",
238+
"pwTraceOptions": {
239+
"enablePythonTracer": True,
240+
},
241+
},
235242
}
236243
})
237244
)
@@ -243,6 +250,31 @@ def test_start_trace_success(self):
243250
)
244251
self.assertIsNotNone(profiling._profile_state.executable)
245252

253+
def test_start_trace_with_max_num_hosts(self):
254+
profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10)
255+
256+
self.mock_toy_computation.assert_called_once()
257+
self.mock_plugin_executable_cls.assert_called_once_with(
258+
json.dumps({
259+
"profileRequest": {
260+
"traceLocation": "gs://test_bucket/test_dir",
261+
"maxNumHosts": 10,
262+
"xprofTraceOptions": {
263+
"traceDirectory": "gs://test_bucket/test_dir",
264+
"pwTraceOptions": {
265+
"enablePythonTracer": True,
266+
},
267+
},
268+
}
269+
})
270+
)
271+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
272+
self.mock_original_start_trace.assert_called_once_with(
273+
log_dir="gs://test_bucket/test_dir",
274+
create_perfetto_link=False,
275+
create_perfetto_trace=False,
276+
)
277+
246278
def test_start_trace_no_toy_computation_second_time(self):
247279
profiling.start_trace("gs://test_bucket/test_dir")
248280
profiling.stop_trace()
@@ -408,6 +440,24 @@ def test_monkey_patched_start_trace(self, profiler_module):
408440
create_perfetto_link=False,
409441
create_perfetto_trace=False,
410442
profiler_options=None,
443+
max_num_hosts=1,
444+
)
445+
446+
@parameterized.named_parameters(
447+
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
448+
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
449+
)
450+
def test_monkey_patched_start_trace_with_max_num_hosts(self, profiler_module):
451+
mocks = self._setup_monkey_patch()
452+
453+
profiler_module.start_trace("gs://bucket/dir", max_num_hosts=3)
454+
455+
mocks["start_trace"].assert_called_once_with(
456+
"gs://bucket/dir",
457+
create_perfetto_link=False,
458+
create_perfetto_trace=False,
459+
profiler_options=None,
460+
max_num_hosts=3,
411461
)
412462

413463
@parameterized.named_parameters(
@@ -444,6 +494,31 @@ def test_create_profile_request_default_options(self, profiler_options):
444494
request,
445495
{
446496
"traceLocation": "gs://bucket/dir",
497+
"maxNumHosts": 1,
498+
"xprofTraceOptions": {
499+
"traceDirectory": "gs://bucket/dir",
500+
"pwTraceOptions": {
501+
"enablePythonTracer": True,
502+
},
503+
},
504+
},
505+
)
506+
507+
def test_create_profile_request_with_max_num_hosts(self):
508+
request = profiling._create_profile_request(
509+
"gs://bucket/dir", max_num_hosts=5
510+
)
511+
self.assertEqual(
512+
request,
513+
{
514+
"traceLocation": "gs://bucket/dir",
515+
"maxNumHosts": 5,
516+
"xprofTraceOptions": {
517+
"traceDirectory": "gs://bucket/dir",
518+
"pwTraceOptions": {
519+
"enablePythonTracer": True,
520+
},
521+
},
447522
},
448523
)
449524

@@ -471,6 +546,7 @@ def test_create_profile_request_with_options(self):
471546
{
472547
"traceLocation": "gs://bucket/dir",
473548
"maxDurationSecs": 2.0,
549+
"maxNumHosts": 1,
474550
"xprofTraceOptions": {
475551
"traceDirectory": "gs://bucket/dir",
476552
"pwTraceOptions": {

0 commit comments

Comments
 (0)