@@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir,
18481848 * args , keywords = keywords )
18491849
18501850 def submit_tasks (self , task_pool_name , block = True , use_dask = False , dask_nodes = 1 ,
1851- dask_ppn = None , launch_interval = 0.0 , use_shifter = False ):
1851+ dask_ppn = None , launch_interval = 0.0 , use_shifter = False , dask_worker_plugin = None ):
18521852 """
18531853 Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``,
18541854 return when all tasks have been launched. If *block* is ``False``, return when all
@@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
18601860 start_time = time .time ()
18611861 self ._send_monitor_event ('IPS_TASK_POOL_BEGIN' , 'task_pool = %s ' % task_pool_name )
18621862 task_pool : TaskPool = self .task_pools [task_pool_name ]
1863- retval = task_pool .submit_tasks (block , use_dask , dask_nodes , dask_ppn , launch_interval , use_shifter )
1863+ retval = task_pool .submit_tasks (block , use_dask , dask_nodes , dask_ppn , launch_interval , use_shifter , dask_worker_plugin )
18641864 elapsed_time = time .time () - start_time
18651865 self ._send_monitor_event ('IPS_TASK_POOL_END' , 'task_pool = %s elapsed time = %.2f S' %
18661866 (task_pool_name , elapsed_time ),
@@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords):
20662066 self .queued_tasks [task_name ] = Task (task_name , nproc , working_dir , binary_fullpath , * args ,
20672067 ** keywords ["keywords" ])
20682068
2069- def submit_dask_tasks (self , block = True , dask_nodes = 1 , dask_ppn = None , use_shifter = False ):
2069+ def submit_dask_tasks (self , block = True , dask_nodes = 1 , dask_ppn = None , use_shifter = False , dask_worker_plugin = None ):
20702070 """Launch tasks in *queued_tasks* using dask.
20712071
20722072 :param block: Unused, this will always return after tasks are submitted
@@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
20772077 :type dask_ppn: int
20782078 :param use_shifter: Option to launch dask scheduler and workers in shifter container
20792079 :type use_shifter: bool
2080+ :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2081+ :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
20802082 """
20812083 services : ServicesProxy = self .services
20822084 self .dask_file_name = os .path .join (os .getcwd (),
@@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
21152117
21162118 self .dask_client = self .dask .distributed .Client (scheduler_file = self .dask_file_name )
21172119
2120+ if dask_worker_plugin is not None :
2121+ self .dask_client .register_worker_plugin (dask_worker_plugin )
2122+
21182123 try :
21192124 self .worker_event_logfile = services .sim_name + '_' + services .get_config_param ("PORTAL_RUNID" ) + '_' + self .name + '_{}.json'
21202125 except KeyError :
@@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
21352140 self .queued_tasks = {}
21362141 return len (self .futures )
21372142
2138- def submit_tasks (self , block = True , use_dask = False , dask_nodes = 1 , dask_ppn = None , launch_interval = 0.0 , use_shifter = False ):
2143+ def submit_tasks (self , block = True , use_dask = False , dask_nodes = 1 , dask_ppn = None , launch_interval = 0.0 , use_shifter = False , dask_worker_plugin = None ):
21392144 """Launch tasks in *queued_tasks*. Finished tasks are handled before
21402145 launching new ones. If *block* is ``True``, the number of
21412146 tasks submitted is returned after all tasks have been launched
@@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
21572162 :type launch_internal: float
21582163 :param use_shifter: Option to launch dask scheduler and workers in shifter container
21592164 :type use_shifter: bool
2160-
2165+ :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2166+ :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
21612167 """
21622168
21632169 if use_dask :
@@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
21672173 self .services .error ("Requested to run dask within shifter but shifter not available" )
21682174 raise Exception ("shifter not found" )
21692175 else :
2170- return self .submit_dask_tasks (block , dask_nodes , dask_ppn , use_shifter )
2176+ return self .submit_dask_tasks (block , dask_nodes , dask_ppn , use_shifter , dask_worker_plugin )
21712177 elif not TaskPool .dask :
21722178 self .services .warning ("Requested use_dask but cannot because import dask failed" )
21732179 elif not self .serial_pool :
0 commit comments