Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def _is_default_profile_options(
def _create_profile_request(
log_dir: os.PathLike[str] | str,
profiler_options: jax.profiler.ProfileOptions | None = None,
max_num_hosts: int = 1,
) -> Mapping[str, Any]:
"""Creates a profile request mapping from the given options."""
profile_request: dict[str, Any] = {
"traceLocation": str(log_dir),
"maxNumHosts": max_num_hosts,
}

if profiler_options is None or _is_default_profile_options(profiler_options):
Expand Down Expand Up @@ -173,6 +175,7 @@ def start_trace(
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None,
max_num_hosts: int = 1,
) -> None:
"""Starts a profiler trace.

Expand Down Expand Up @@ -201,6 +204,8 @@ def start_trace(
This feature is experimental for Pathways on Cloud and may not be fully
supported.
profiler_options: Profiler options to configure the profiler for collection.
max_num_hosts: An optional integer to limit the number of hosts profiled
(defaults to 1).
"""
if not str(log_dir).startswith("gs://"):
raise ValueError(f"log_dir must be a GCS bucket path, got {log_dir}")
Expand All @@ -218,7 +223,9 @@ def start_trace(
)
profiler_options = None

profile_request = _create_profile_request(log_dir, profiler_options)
profile_request = _create_profile_request(
log_dir, profiler_options, max_num_hosts=max_num_hosts
)

_logger.debug("Profile request: %s", profile_request)

Expand Down Expand Up @@ -366,13 +373,15 @@ def start_trace_patch(
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
profiler_options: jax.profiler.ProfileOptions | None = None,
max_num_hosts: int = 1,
) -> None:
_logger.debug("jax.profile.start_trace patched with pathways' start_trace")
start_trace(
log_dir,
create_perfetto_link=create_perfetto_link,
create_perfetto_trace=create_perfetto_trace,
profiler_options=profiler_options,
max_num_hosts=max_num_hosts,
)

jax.profiler.start_trace = start_trace_patch
Expand Down
52 changes: 52 additions & 0 deletions pathwaysutils/test/profiling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def test_start_trace_success(self):
json.dumps({
"profileRequest": {
"traceLocation": "gs://test_bucket/test_dir",
"maxNumHosts": 1,
}
})
)
Expand All @@ -243,6 +244,25 @@ def test_start_trace_success(self):
)
self.assertIsNotNone(profiling._profile_state.executable)

def test_start_trace_with_max_num_hosts(self):
profiling.start_trace("gs://test_bucket/test_dir", max_num_hosts=10)

self.mock_toy_computation.assert_called_once()
self.mock_plugin_executable_cls.assert_called_once_with(
json.dumps({
"profileRequest": {
"traceLocation": "gs://test_bucket/test_dir",
"maxNumHosts": 10,
}
})
)
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
self.mock_original_start_trace.assert_called_once_with(
log_dir="gs://test_bucket/test_dir",
create_perfetto_link=False,
create_perfetto_trace=False,
)

def test_start_trace_no_toy_computation_second_time(self):
profiling.start_trace("gs://test_bucket/test_dir")
profiling.stop_trace()
Expand Down Expand Up @@ -408,6 +428,24 @@ def test_monkey_patched_start_trace(self, profiler_module):
create_perfetto_link=False,
create_perfetto_trace=False,
profiler_options=None,
max_num_hosts=1,
)

@parameterized.named_parameters(
dict(testcase_name="jax_profiler", profiler_module=jax.profiler),
dict(testcase_name="jax_src_profiler", profiler_module=jax._src.profiler),
)
def test_monkey_patched_start_trace_with_max_num_hosts(self, profiler_module):
mocks = self._setup_monkey_patch()

profiler_module.start_trace("gs://bucket/dir", max_num_hosts=3)

mocks["start_trace"].assert_called_once_with(
"gs://bucket/dir",
create_perfetto_link=False,
create_perfetto_trace=False,
profiler_options=None,
max_num_hosts=3,
)

@parameterized.named_parameters(
Expand Down Expand Up @@ -444,6 +482,19 @@ def test_create_profile_request_default_options(self, profiler_options):
request,
{
"traceLocation": "gs://bucket/dir",
"maxNumHosts": 1,
},
)

def test_create_profile_request_with_max_num_hosts(self):
request = profiling._create_profile_request(
"gs://bucket/dir", max_num_hosts=5
)
self.assertEqual(
request,
{
"traceLocation": "gs://bucket/dir",
"maxNumHosts": 5,
},
)

Expand Down Expand Up @@ -471,6 +522,7 @@ def test_create_profile_request_with_options(self):
{
"traceLocation": "gs://bucket/dir",
"maxDurationSecs": 2.0,
"maxNumHosts": 1,
"xprofTraceOptions": {
"traceDirectory": "gs://bucket/dir",
"pwTraceOptions": {
Expand Down
Loading