@@ -232,6 +232,13 @@ def test_start_trace_success(self):
232232 json .dumps ({
233233 "profileRequest" : {
234234 "traceLocation" : "gs://test_bucket/test_dir" ,
235+ "maxNumHosts" : 1 ,
236+ "xprofTraceOptions" : {
237+ "traceDirectory" : "gs://test_bucket/test_dir" ,
238+ "pwTraceOptions" : {
239+ "enablePythonTracer" : True ,
240+ },
241+ },
235242 }
236243 })
237244 )
@@ -243,6 +250,31 @@ def test_start_trace_success(self):
243250 )
244251 self .assertIsNotNone (profiling ._profile_state .executable )
245252
253+ def test_start_trace_with_max_num_hosts (self ):
254+ profiling .start_trace ("gs://test_bucket/test_dir" , max_num_hosts = 10 )
255+
256+ self .mock_toy_computation .assert_called_once ()
257+ self .mock_plugin_executable_cls .assert_called_once_with (
258+ json .dumps ({
259+ "profileRequest" : {
260+ "traceLocation" : "gs://test_bucket/test_dir" ,
261+ "maxNumHosts" : 10 ,
262+ "xprofTraceOptions" : {
263+ "traceDirectory" : "gs://test_bucket/test_dir" ,
264+ "pwTraceOptions" : {
265+ "enablePythonTracer" : True ,
266+ },
267+ },
268+ }
269+ })
270+ )
271+ self .mock_plugin_executable_cls .return_value .call .assert_called_once ()
272+ self .mock_original_start_trace .assert_called_once_with (
273+ log_dir = "gs://test_bucket/test_dir" ,
274+ create_perfetto_link = False ,
275+ create_perfetto_trace = False ,
276+ )
277+
246278 def test_start_trace_no_toy_computation_second_time (self ):
247279 profiling .start_trace ("gs://test_bucket/test_dir" )
248280 profiling .stop_trace ()
@@ -408,6 +440,24 @@ def test_monkey_patched_start_trace(self, profiler_module):
408440 create_perfetto_link = False ,
409441 create_perfetto_trace = False ,
410442 profiler_options = None ,
443+ max_num_hosts = 1 ,
444+ )
445+
446+ @parameterized .named_parameters (
447+ dict (testcase_name = "jax_profiler" , profiler_module = jax .profiler ),
448+ dict (testcase_name = "jax_src_profiler" , profiler_module = jax ._src .profiler ),
449+ )
450+ def test_monkey_patched_start_trace_with_max_num_hosts (self , profiler_module ):
451+ mocks = self ._setup_monkey_patch ()
452+
453+ profiler_module .start_trace ("gs://bucket/dir" , max_num_hosts = 3 )
454+
455+ mocks ["start_trace" ].assert_called_once_with (
456+ "gs://bucket/dir" ,
457+ create_perfetto_link = False ,
458+ create_perfetto_trace = False ,
459+ profiler_options = None ,
460+ max_num_hosts = 3 ,
411461 )
412462
413463 @parameterized .named_parameters (
@@ -444,6 +494,31 @@ def test_create_profile_request_default_options(self, profiler_options):
444494 request ,
445495 {
446496 "traceLocation" : "gs://bucket/dir" ,
497+ "maxNumHosts" : 1 ,
498+ "xprofTraceOptions" : {
499+ "traceDirectory" : "gs://bucket/dir" ,
500+ "pwTraceOptions" : {
501+ "enablePythonTracer" : True ,
502+ },
503+ },
504+ },
505+ )
506+
507+ def test_create_profile_request_with_max_num_hosts (self ):
508+ request = profiling ._create_profile_request (
509+ "gs://bucket/dir" , max_num_hosts = 5
510+ )
511+ self .assertEqual (
512+ request ,
513+ {
514+ "traceLocation" : "gs://bucket/dir" ,
515+ "maxNumHosts" : 5 ,
516+ "xprofTraceOptions" : {
517+ "traceDirectory" : "gs://bucket/dir" ,
518+ "pwTraceOptions" : {
519+ "enablePythonTracer" : True ,
520+ },
521+ },
447522 },
448523 )
449524
@@ -471,6 +546,7 @@ def test_create_profile_request_with_options(self):
471546 {
472547 "traceLocation" : "gs://bucket/dir" ,
473548 "maxDurationSecs" : 2.0 ,
549+ "maxNumHosts" : 1 ,
474550 "xprofTraceOptions" : {
475551 "traceDirectory" : "gs://bucket/dir" ,
476552 "pwTraceOptions" : {
0 commit comments