Skip to content

Commit a7e9e13

Browse files
lukebaumann authored and copybara-github committed
Fix JaxRuntimeError during profiler stop_trace array verification
When executing a profiler request with Xprof Trace Options, the IFRT proxy outputs a single JAX shaped array. Previously, the Python profiler expected 0 outputs, causing a JaxRuntimeError (`Mismatch between out_handlers and num_results: 0 vs 1`). This CL updates the profiler state implementation to expect a single output of shape `(1,)` from `PluginExecutable.call()` when stopping a trace with Xprof options, consuming the URL suffix parameter and fixing the crash.

PiperOrigin-RevId: 888268746
1 parent 44d0853 commit a7e9e13

2 files changed

Lines changed: 81 additions & 36 deletions

File tree

pathwaysutils/profiling.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -35,14 +35,17 @@
3535

3636
class _ProfileState:
3737
executable: plugin_executable.PluginExecutable | None = None
38+
profile_request: Mapping[str, Any] | None = None
3839
lock: threading.Lock
3940

4041
def __init__(self) -> None:
4142
self.executable = None
43+
self.profile_request = None
4244
self.lock = threading.Lock()
4345

4446
def reset(self) -> None:
4547
self.executable = None
48+
self.profile_request = None
4649

4750

4851
_first_profile_start = True
@@ -153,6 +156,7 @@ def _start_pathways_trace_from_profile_request(
153156
_profile_state.executable = plugin_executable.PluginExecutable(
154157
json.dumps({"profileRequest": profile_request})
155158
)
159+
_profile_state.profile_request = profile_request
156160
try:
157161
_, result_future = _profile_state.executable.call()
158162
result_future.result()
@@ -233,8 +237,19 @@ def stop_trace() -> None:
233237
if _profile_state.executable is None:
234238
raise ValueError("stop_trace called before a trace is being taken!")
235239
try:
236-
_, result_future = _profile_state.executable.call()
237-
result_future.result()
240+
if (
241+
_profile_state.profile_request
242+
and "xprofTraceOptions" in _profile_state.profile_request
243+
):
244+
out_avals = [jax.core.ShapedArray((1,), jnp.object_)]
245+
out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])]
246+
else:
247+
out_avals = ()
248+
out_shardings = ()
249+
250+
_profile_state.executable.call(
251+
out_avals=out_avals, out_shardings=out_shardings
252+
)
238253
finally:
239254
_profile_state.reset()
240255
finally:

pathwaysutils/test/profiling_test.py

Lines changed: 64 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414

1515
import json
1616
import logging
17-
import unittest
1817
from unittest import mock
1918

2019
from absl.testing import absltest
2120
from absl.testing import parameterized
2221
import jax
22+
from jax import numpy as jnp
2323
from pathwaysutils import profiling
2424
import requests
2525

@@ -213,10 +213,9 @@ def test_lock_released_on_stop_failure(self):
213213
"""Tests that the lock is released if stop_trace fails."""
214214
profiling.start_trace("gs://test_bucket/test_dir3")
215215
self.assertFalse(profiling._profile_state.lock.locked())
216-
mock_result = (
217-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
216+
self.mock_plugin_executable_cls.return_value.call.side_effect = (
217+
RuntimeError("stop failed")
218218
)
219-
mock_result.result.side_effect = RuntimeError("stop failed")
220219
with self.assertRaisesRegex(RuntimeError, "stop failed"):
221220
profiling.stop_trace()
222221
self.assertFalse(profiling._profile_state.lock.locked())
@@ -277,6 +276,34 @@ def test_stop_trace_success(self):
277276
with self.subTest("executable_is_none"):
278277
self.assertIsNone(profiling._profile_state.executable)
279278

279+
@absltest.skipIf(
280+
jax.version.__version_info__ < (0, 9, 2),
281+
"ProfileOptions requires JAX 0.9.2 or newer",
282+
)
283+
def test_stop_trace_with_xprof_options_passes_out_avals(self):
284+
options = jax.profiler.ProfileOptions()
285+
options.duration_ms = 2000
286+
287+
# Bypass start_trace and explicitly populate profile state
288+
request = profiling._create_profile_request(
289+
"gs://test_bucket/test_dir", options
290+
)
291+
profiling._profile_state.profile_request = request
292+
profiling._profile_state.executable = (
293+
self.mock_plugin_executable_cls.return_value
294+
)
295+
296+
profiling.stop_trace()
297+
298+
self.mock_plugin_executable_cls.return_value.call.assert_called_once()
299+
_, kwargs = self.mock_plugin_executable_cls.return_value.call.call_args
300+
self.assertIn("out_avals", kwargs)
301+
self.assertIn("out_shardings", kwargs)
302+
self.assertLen(kwargs["out_avals"], 1)
303+
# Check that it's an object dtype ShapedArray
304+
self.assertEqual(kwargs["out_avals"][0].shape, (1,))
305+
self.assertEqual(kwargs["out_avals"][0].dtype, jnp.object_)
306+
280307
def test_stop_trace_before_start_error(self):
281308
with self.assertRaisesRegex(
282309
ValueError, "stop_trace called before a trace is being taken!"
@@ -406,7 +433,7 @@ def test_create_profile_request_default_options(self, profiler_options):
406433
},
407434
)
408435

409-
@unittest.skipIf(
436+
@absltest.skipIf(
410437
jax.version.__version_info__ < (0, 9, 2),
411438
"ProfileOptions requires JAX 0.9.2 or newer",
412439
)
@@ -444,41 +471,45 @@ def test_create_profile_request_with_options(self):
444471
},
445472
)
446473

447-
@unittest.skipIf(
474+
@absltest.skipIf(
448475
jax.version.__version_info__ < (0, 9, 2),
449476
"ProfileOptions requires JAX 0.9.2 or newer",
450477
)
451478
@parameterized.parameters(
452479
({"traceLocation": "gs://test_bucket/test_dir"},),
453-
({
454-
"traceLocation": "gs://test_bucket/test_dir",
455-
"blockUntilStart": True,
456-
"maxDurationSecs": 10.0,
457-
"devices": {"deviceIds": [1, 2]},
458-
"includeResourceManagers": True,
459-
"maxNumHosts": 5,
460-
"xprofTraceOptions": {
480+
(
481+
{
482+
"traceLocation": "gs://test_bucket/test_dir",
461483
"blockUntilStart": True,
462-
"traceDirectory": "gs://test_bucket/test_dir",
484+
"maxDurationSecs": 10.0,
485+
"devices": {"deviceIds": [1, 2]},
486+
"includeResourceManagers": True,
487+
"maxNumHosts": 5,
488+
"xprofTraceOptions": {
489+
"blockUntilStart": True,
490+
"traceDirectory": "gs://test_bucket/test_dir",
491+
},
463492
},
464-
},),
465-
({
466-
"traceLocation": "gs://bucket/dir",
467-
"xprofTraceOptions": {
468-
"hostTraceLevel": 0,
469-
"traceOptions": {
470-
"traceMode": "TRACE_COMPUTE",
471-
"numSparseCoresToTrace": 1,
472-
"numSparseCoreTilesToTrace": 2,
473-
"numChipsToProfilePerTask": 3,
474-
"powerTraceLevel": 4,
475-
"enableFwThrottleEvent": True,
476-
"enableFwPowerLevelEvent": True,
477-
"enableFwThermalEvent": True,
493+
),
494+
(
495+
{
496+
"traceLocation": "gs://bucket/dir",
497+
"xprofTraceOptions": {
498+
"hostTraceLevel": 0,
499+
"traceOptions": {
500+
"traceMode": "TRACE_COMPUTE",
501+
"numSparseCoresToTrace": 1,
502+
"numSparseCoreTilesToTrace": 2,
503+
"numChipsToProfilePerTask": 3,
504+
"powerTraceLevel": 4,
505+
"enableFwThrottleEvent": True,
506+
"enableFwPowerLevelEvent": True,
507+
"enableFwThermalEvent": True,
508+
},
509+
"traceDirectory": "gs://bucket/dir",
478510
},
479-
"traceDirectory": "gs://bucket/dir",
480511
},
481-
},),
512+
),
482513
)
483514

484515
def test_start_pathways_trace_from_profile_request(self, profile_request):
@@ -496,10 +527,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
496527
"""Tests that original_stop_trace is called if pathways stop_trace fails."""
497528
profiling.start_trace("gs://test_bucket/test_dir")
498529
self.assertFalse(profiling._profile_state.lock.locked())
499-
mock_result = (
500-
self.mock_plugin_executable_cls.return_value.call.return_value[1]
530+
self.mock_plugin_executable_cls.return_value.call.side_effect = (
531+
RuntimeError("stop failed")
501532
)
502-
mock_result.result.side_effect = RuntimeError("stop failed")
503533
with self.assertRaisesRegex(RuntimeError, "stop failed"):
504534
profiling.stop_trace()
505535
self.mock_original_stop_trace.assert_called_once()

0 commit comments

Comments
 (0)