1414
1515import json
1616import logging
17- import unittest
1817from unittest import mock
1918
2019from absl .testing import absltest
2120from absl .testing import parameterized
2221import jax
22+ from jax import numpy as jnp
2323from pathwaysutils import profiling
2424import requests
2525
@@ -213,10 +213,9 @@ def test_lock_released_on_stop_failure(self):
213213 """Tests that the lock is released if stop_trace fails."""
214214 profiling .start_trace ("gs://test_bucket/test_dir3" )
215215 self .assertFalse (profiling ._profile_state .lock .locked ())
216- mock_result = (
217- self . mock_plugin_executable_cls . return_value . call . return_value [ 1 ]
216+ self . mock_plugin_executable_cls . return_value . call . side_effect = (
217+ RuntimeError ( "stop failed" )
218218 )
219- mock_result .result .side_effect = RuntimeError ("stop failed" )
220219 with self .assertRaisesRegex (RuntimeError , "stop failed" ):
221220 profiling .stop_trace ()
222221 self .assertFalse (profiling ._profile_state .lock .locked ())
@@ -277,6 +276,34 @@ def test_stop_trace_success(self):
277276 with self .subTest ("executable_is_none" ):
278277 self .assertIsNone (profiling ._profile_state .executable )
279278
279+ @absltest .skipIf (
280+ jax .version .__version_info__ < (0 , 9 , 2 ),
281+ "ProfileOptions requires JAX 0.9.2 or newer" ,
282+ )
283+ def test_stop_trace_with_xprof_options_passes_out_avals (self ):
284+ options = jax .profiler .ProfileOptions ()
285+ options .duration_ms = 2000
286+
287+ # Bypass start_trace and explicitly populate profile state
288+ request = profiling ._create_profile_request (
289+ "gs://test_bucket/test_dir" , options
290+ )
291+ profiling ._profile_state .profile_request = request
292+ profiling ._profile_state .executable = (
293+ self .mock_plugin_executable_cls .return_value
294+ )
295+
296+ profiling .stop_trace ()
297+
298+ self .mock_plugin_executable_cls .return_value .call .assert_called_once ()
299+ _ , kwargs = self .mock_plugin_executable_cls .return_value .call .call_args
300+ self .assertIn ("out_avals" , kwargs )
301+ self .assertIn ("out_shardings" , kwargs )
302+ self .assertLen (kwargs ["out_avals" ], 1 )
303+ # Check that it's an object dtype ShapedArray
304+ self .assertEqual (kwargs ["out_avals" ][0 ].shape , (1 ,))
305+ self .assertEqual (kwargs ["out_avals" ][0 ].dtype , jnp .object_ )
306+
280307 def test_stop_trace_before_start_error (self ):
281308 with self .assertRaisesRegex (
282309 ValueError , "stop_trace called before a trace is being taken!"
@@ -406,7 +433,7 @@ def test_create_profile_request_default_options(self, profiler_options):
406433 },
407434 )
408435
409- @unittest .skipIf (
436+ @absltest .skipIf (
410437 jax .version .__version_info__ < (0 , 9 , 2 ),
411438 "ProfileOptions requires JAX 0.9.2 or newer" ,
412439 )
@@ -444,41 +471,45 @@ def test_create_profile_request_with_options(self):
444471 },
445472 )
446473
447- @unittest .skipIf (
474+ @absltest .skipIf (
448475 jax .version .__version_info__ < (0 , 9 , 2 ),
449476 "ProfileOptions requires JAX 0.9.2 or newer" ,
450477 )
451478 @parameterized .parameters (
452479 ({"traceLocation" : "gs://test_bucket/test_dir" },),
453- ({
454- "traceLocation" : "gs://test_bucket/test_dir" ,
455- "blockUntilStart" : True ,
456- "maxDurationSecs" : 10.0 ,
457- "devices" : {"deviceIds" : [1 , 2 ]},
458- "includeResourceManagers" : True ,
459- "maxNumHosts" : 5 ,
460- "xprofTraceOptions" : {
480+ (
481+ {
482+ "traceLocation" : "gs://test_bucket/test_dir" ,
461483 "blockUntilStart" : True ,
462- "traceDirectory" : "gs://test_bucket/test_dir" ,
484+ "maxDurationSecs" : 10.0 ,
485+ "devices" : {"deviceIds" : [1 , 2 ]},
486+ "includeResourceManagers" : True ,
487+ "maxNumHosts" : 5 ,
488+ "xprofTraceOptions" : {
489+ "blockUntilStart" : True ,
490+ "traceDirectory" : "gs://test_bucket/test_dir" ,
491+ },
463492 },
464- },),
465- ({
466- "traceLocation" : "gs://bucket/dir" ,
467- "xprofTraceOptions" : {
468- "hostTraceLevel" : 0 ,
469- "traceOptions" : {
470- "traceMode" : "TRACE_COMPUTE" ,
471- "numSparseCoresToTrace" : 1 ,
472- "numSparseCoreTilesToTrace" : 2 ,
473- "numChipsToProfilePerTask" : 3 ,
474- "powerTraceLevel" : 4 ,
475- "enableFwThrottleEvent" : True ,
476- "enableFwPowerLevelEvent" : True ,
477- "enableFwThermalEvent" : True ,
493+ ),
494+ (
495+ {
496+ "traceLocation" : "gs://bucket/dir" ,
497+ "xprofTraceOptions" : {
498+ "hostTraceLevel" : 0 ,
499+ "traceOptions" : {
500+ "traceMode" : "TRACE_COMPUTE" ,
501+ "numSparseCoresToTrace" : 1 ,
502+ "numSparseCoreTilesToTrace" : 2 ,
503+ "numChipsToProfilePerTask" : 3 ,
504+ "powerTraceLevel" : 4 ,
505+ "enableFwThrottleEvent" : True ,
506+ "enableFwPowerLevelEvent" : True ,
507+ "enableFwThermalEvent" : True ,
508+ },
509+ "traceDirectory" : "gs://bucket/dir" ,
478510 },
479- "traceDirectory" : "gs://bucket/dir" ,
480511 },
481- }, ),
512+ ),
482513 )
483514
484515 def test_start_pathways_trace_from_profile_request (self , profile_request ):
@@ -496,10 +527,9 @@ def test_original_stop_trace_called_on_stop_failure(self):
496527 """Tests that original_stop_trace is called if pathways stop_trace fails."""
497528 profiling .start_trace ("gs://test_bucket/test_dir" )
498529 self .assertFalse (profiling ._profile_state .lock .locked ())
499- mock_result = (
500- self . mock_plugin_executable_cls . return_value . call . return_value [ 1 ]
530+ self . mock_plugin_executable_cls . return_value . call . side_effect = (
531+ RuntimeError ( "stop failed" )
501532 )
502- mock_result .result .side_effect = RuntimeError ("stop failed" )
503533 with self .assertRaisesRegex (RuntimeError , "stop failed" ):
504534 profiling .stop_trace ()
505535 self .mock_original_stop_trace .assert_called_once ()
0 commit comments