From e1c906bdf93f2923b72cfbd545006e7ebdd68388 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 18 Dec 2025 17:55:25 +0000 Subject: [PATCH 1/4] Implement gradient redux at nodes --- mosaic/core/task.py | 3 +- mosaic/core/tessera.py | 3 +- mosaic/runtime/monitor.py | 2 + mosaic/runtime/runtime.py | 45 +++++++++++++++++- mosaic/runtime/warehouse.py | 54 +++++++++++++++++++-- stride/core.py | 86 +++++++++++++++++++++++++++------- stride/physics/problem_type.py | 20 +++++++- 7 files changed, 188 insertions(+), 25 deletions(-) diff --git a/mosaic/core/task.py b/mosaic/core/task.py index 96f4ccd..d51c26a 100644 --- a/mosaic/core/task.py +++ b/mosaic/core/task.py @@ -675,8 +675,9 @@ class TaskProxy(ProxyBase): def __init__(self, proxy, method, *args, **kwargs): super().__init__(*args, **kwargs) + cls = proxy._cls.cls if hasattr(proxy._cls, 'cls') else proxy._cls self._uid = '%s-%s-%s-%s' % ('task', - proxy._cls.cls.__name__.lower(), + cls.__name__.lower(), method, uuid.uuid4().hex) self._tessera_proxy = proxy diff --git a/mosaic/core/tessera.py b/mosaic/core/tessera.py index 5e1e534..1c56dae 100644 --- a/mosaic/core/tessera.py +++ b/mosaic/core/tessera.py @@ -632,9 +632,10 @@ def __init__(self, cls, *args, **kwargs): runtime = kwargs.pop('runtime', None) self._runtime_id = runtime.uid if hasattr(runtime, 'uid') else runtime + uid = kwargs.pop('uid', None) self._uid = '%s-%s-%s' % ('tess', self._cls.__name__.lower(), - uuid.uuid4().hex) + uuid.uuid4().hex) if uid is None else uid self._cls_attr_names = None self._set_cls() diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index bf291c8..9c29f51 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -557,6 +557,8 @@ async def barrier(self, sender_id, timeout=None): if timeout is not None and (time.time() - tic) > timeout: break + await self._local_warehouse.run_barrier_tasks(reply=True) + self._monitored_tasks = dict() self._runtime_tasks = defaultdict(list) self._dirty_tasks = set() diff --git a/mosaic/runtime/runtime.py b/mosaic/runtime/runtime.py index 8ba4484..03fa6f4 100644 --- a/mosaic/runtime/runtime.py +++ b/mosaic/runtime/runtime.py @@ -179,6 +179,7 @@ def __init__(self, **kwargs): self._dealloc_queue = [] self._maintenance_queue = [] self._maintenance_msgs = {} + self._barrier_tasks = [] self._inside_async_for = False @@ -339,6 +340,25 @@ def cpu_load(self): return 0 + def register_barrier_task(self, task): + """ + Register task for execution at barrier. + + Parameters + ---------- + task + + Returns + ------- + + """ + self._barrier_tasks.append(task) + + async def run_barrier_tasks(self, sender_id): + for task in self._barrier_tasks: + await task() + self._barrier_tasks = [] + async def barrier(self, timeout=None): """ Wait until all pending tasks are done. If no timeout is @@ -1080,7 +1100,27 @@ async def get(self, uid, cache=True): return obj - async def drop(self, uid, cache_only=False): + async def exec(self, uid, func, func_args=None, func_kwargs=None): + """ + Retrieve an object from the warehouse. + + Parameters + ---------- + uid + func + func_args + func_kwargs + + Returns + ------- + + """ + ret = await self._local_warehouse.exec_remote(uid=uid, func=func, + func_args=func_args, func_kwargs=func_kwargs, + reply=True) + return ret + + async def drop(self, uid, cache_only=False, propagate=False): """ Delete an object from the warehouse. @@ -1088,13 +1128,14 @@ async def drop(self, uid, cache_only=False): ---------- uid cache_only + propagate Returns ------- """ if not cache_only: - await self._local_warehouse.drop_remote(uid=uid) + await self._local_warehouse.drop_remote(uid=uid, propagate=propagate) obj_uid = uid.uid if hasattr(uid, 'uid') else uid try: diff --git a/mosaic/runtime/warehouse.py b/mosaic/runtime/warehouse.py index 152bd8b..2849eb9 100644 --- a/mosaic/runtime/warehouse.py +++ b/mosaic/runtime/warehouse.py @@ -167,7 +167,42 @@ async def get_remote(self, sender_id, uid, warehouse_id=None, node_id=None): return obj - async def drop_remote(self, sender_id, uid): + async def exec_remote(self, sender_id, uid, func, func_args=None, func_kwargs=None): + """ + Retrieve an object from the warehouse and execute function on it. + + Parameters + ---------- + sender_id + uid + func + func_args + func_kwargs + + Returns + ------- + + """ + if isinstance(uid, WarehouseObject): + obj_id = uid.uid + else: + obj_id = uid + + try: + obj = self._local_warehouse[obj_id] + except KeyError: + obj = None + + func_args = func_args or () + func_kwargs = func_kwargs or {} + obj = await func(obj, *func_args, **func_kwargs) + + self._local_warehouse[obj_id] = obj + + warehouse_obj = WarehouseObject(obj, uid=obj_id) + return warehouse_obj + + async def drop_remote(self, sender_id, uid, propagate=False): """ Delete an object from the warehouse. @@ -175,16 +210,27 @@ async def drop_remote(self, sender_id, uid): ---------- sender_id uid + propagate Returns ------- """ if isinstance(uid, WarehouseObject): - uid = uid.uid + obj_id = uid.uid - if uid in self._local_warehouse: - del self._local_warehouse[uid] + if propagate: + node_id = uid.node_id + warehouse_id = uid.warehouse_id + + if node_id is not None and warehouse_id is not None: + if node_id not in self._warehouses: + self._warehouses[node_id] = self.proxy(uid=warehouse_id) + + await self._warehouses[node_id].drop_remote(uid=uid) + + if obj_id in self._local_warehouse: + del self._local_warehouse[obj_id] async def push_remote(self, sender_id, __dict__, uid=None, warehouse_id=None, node_id=None, diff --git a/stride/core.py b/stride/core.py index e08de56..c82cc89 100644 --- a/stride/core.py +++ b/stride/core.py @@ -2,14 +2,13 @@ import uuid import asyncio import inspect -import numpy as np from abc import abstractmethod from collections import OrderedDict import mosaic from mosaic import types from mosaic.core.base import CMDBase -from mosaic.core import TaskProxy +from mosaic.core import TesseraProxy, TaskProxy __all__ = ['Variable', 'Operator'] @@ -35,7 +34,6 @@ async def _maybe_sum(a, b): return b[0] + a, b[1] elif isinstance(a, tuple) and isinstance(b, tuple): return a[0] + b[0], a[1] + b[1] - return a + b @@ -295,6 +293,9 @@ def __init__(self, *args, **kwargs): self.graph = Graph() self.prev_op = None self.needs_grad = kwargs.pop('needs_grad', False) + + self._redux_grads = set() + self._redux_task = False async def adjoint(self, grad=None, **kwargs): """ @@ -322,6 +323,17 @@ async def adjoint(self, grad=None, **kwargs): runtime = mosaic.runtime() + async def redux(rec_grads, *grads): + if rec_grads is None: + sums = [ + _maybe_sum(None, g) for g in grads + ] + else: + sums = [ + _maybe_sum(r, g) for r, g in zip(rec_grads, grads) + ] + return await asyncio.gather(*sums) + def dealloc(objs): def _dealloc(*args): loop = mosaic.get_event_loop() @@ -351,7 +363,9 @@ def _dealloc(*args): except AttributeError: method = getattr(node.op.obj, node.method) if hasattr(node.op, 'is_parameter') and node.op.is_parameter: - ret = method(*output_grads, **{**kwargs_, **{'eager': True}}) + redux_grad = await runtime.exec('redux-%s' % node.op.uid, redux, output_grads) + ret = method((redux_grad,), **{**kwargs_, **{'eager': True, 'redux': True}}) + else: ret = method(*output_grads, **kwargs_) @@ -539,7 +553,7 @@ def process_grad(self): """ raise NotImplementedError('Unimplemented Variable method process_grad') - async def __call_adjoint__(self, grad, **kwargs): + async def __call_adjoint__(self, grad, redux=False, **kwargs): """ Adjoint operation of the variable, which accumulates the given gradient on the ``Variable.grad`` attribute. @@ -553,25 +567,65 @@ async def __call_adjoint__(self, grad, **kwargs): ------- """ + if redux: + self._redux_grads.add(grad[0]) + + if not self._redux_task: + runtime = mosaic.runtime() + runtime.register_barrier_task(self.__redux_adjoint__) + self._redux_task = True + + return + if grad is None or not self.needs_grad or self.grad is None: return - grad_data = grad.data if hasattr(grad, 'data') else grad - is_nan = np.any(np.isnan(grad_data)) - is_inf = np.any(np.isinf(grad_data)) + self.grad += grad + + async def __call_adjoint_list__(self, *grad, **kwargs): + """ + List adjoint operation of the variable, which accumulates the given + gradient on the ``Variable.grad`` attribute. + + Parameters + ---------- + grad : tuple + Provided gradient + + Returns + ------- - if is_nan or is_inf: - msg = 'Nan or inf detected in %s' % self.name + """ + for g in grad: + await self.__call_adjoint__(g[0]) - problem = kwargs.pop('problem', None) - shot_id = problem.shot.id if problem is not None else kwargs.pop('shot_id', None) - if shot_id is not None: - msg = '(ShotID %d) ' % shot_id + msg + async def __redux_adjoint__(self): + """ + Reduction adjoint operation of the variable, which accumulates the given + gradient on the ``Variable.grad`` attribute. - mosaic.logger().warn(msg) + Parameters + ---------- + + Returns + ------- + + """ + if not hasattr(self, '_tessera'): return - self.grad += grad + tess = self._tessera + redux_proxy = TesseraProxy(tess._cls, runtime=tess.runtime_id, uid=tess.uid) + redux_proxy.init_future.set_result(True) + redux_proxy.state_changed('listening') + redux_task = TaskProxy(redux_proxy, '__call_adjoint_list__', *self._redux_grads) + await redux_proxy._init_task(redux_task, *self._redux_grads) + await redux_task + + drops = [] + for g in self._redux_grads: + drops.append(g.drop(propagate=True)) + await asyncio.gather(*drops) def __repr__(self): return self.name diff --git a/stride/physics/problem_type.py b/stride/physics/problem_type.py index 83f7432..8970369 100644 --- a/stride/physics/problem_type.py +++ b/stride/physics/problem_type.py @@ -1,4 +1,5 @@ from abc import ABC +import numpy as np import mosaic @@ -292,6 +293,23 @@ def get_grad(self, *wrt, **kwargs): if method is None: raise ValueError('Variable %s not implemented' % variable.name) - grads.append(method(variable, wrt=wrt, **kwargs)) + grad = method(variable, wrt=wrt, **kwargs) + + grad_data = grad.data if hasattr(grad, 'data') else grad + is_nan = np.any(np.isnan(grad_data)) + is_inf = np.any(np.isinf(grad_data)) + + if is_nan or is_inf: + msg = 'Nan or inf detected in %s' % self.name + + problem = kwargs.pop('problem', None) + shot_id = problem.shot.id if problem is not None else kwargs.pop('shot_id', None) + if shot_id is not None: + msg = '(ShotID %d) ' % shot_id + msg + + self.logger.warn(msg) + return + + grads.append(grad) return tuple(grads) From 4e6b8b2d79ef7b1a55d30d8a2c715ec1e05d8957 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 18 Dec 2025 19:39:44 +0000 Subject: [PATCH 2/4] Parallelise redux --- mosaic/runtime/monitor.py | 11 +++++++---- stride/core.py | 33 +++++++++++++-------------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 9c29f51..89a2ad9 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -542,17 +542,20 @@ async def barrier(self, sender_id, timeout=None): pending_tasks.append(task) self.logger.info('Pending barrier tasks %d' % len(pending_tasks)) + num_tasks = len(self._monitored_tasks) + tic = time.time() while pending_tasks: - await asyncio.sleep(0.5) + await asyncio.sleep(0.1) for task in pending_tasks: if task.state in ['done', 'failed', 'collected']: pending_tasks.remove(task) - for task in self._monitored_tasks.values(): - if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks: - pending_tasks.append(task) + if len(self._monitored_tasks) > num_tasks: + for task in self._monitored_tasks.values(): + if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks: + pending_tasks.append(task) if timeout is not None and (time.time() - tic) > timeout: break diff --git a/stride/core.py b/stride/core.py index c82cc89..f8e8b0a 100644 --- a/stride/core.py +++ b/stride/core.py @@ -580,24 +580,10 @@ async def __call_adjoint__(self, grad, redux=False, **kwargs): if grad is None or not self.needs_grad or self.grad is None: return - self.grad += grad - - async def __call_adjoint_list__(self, *grad, **kwargs): - """ - List adjoint operation of the variable, which accumulates the given - gradient on the ``Variable.grad`` attribute. - - Parameters - ---------- - grad : tuple - Provided gradient + if isinstance(grad, (list, tuple)): + grad = grad[0] - Returns - ------- - - """ - for g in grad: - await self.__call_adjoint__(g[0]) + self.grad += grad async def __redux_adjoint__(self): """ @@ -618,9 +604,16 @@ async def __redux_adjoint__(self): redux_proxy = TesseraProxy(tess._cls, runtime=tess.runtime_id, uid=tess.uid) redux_proxy.init_future.set_result(True) redux_proxy.state_changed('listening') - redux_task = TaskProxy(redux_proxy, '__call_adjoint_list__', *self._redux_grads) - await redux_proxy._init_task(redux_task, *self._redux_grads) - await redux_task + + inits = [] + tasks = [] + for g in self._redux_grads: + redux_task = TaskProxy(redux_proxy, '__call_adjoint__', g) + inits.append(redux_proxy._init_task(redux_task, g)) + tasks.append(redux_task) + + await asyncio.gather(*inits) + await asyncio.gather(*tasks) drops = [] for g in self._redux_grads: From 99ef4c514fa9199533a7568a6e6a1a344efa3a94 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 23 Dec 2025 11:59:24 +0000 Subject: [PATCH 3/4] Index redux tasks by warehouse --- stride/core.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/stride/core.py b/stride/core.py index f8e8b0a..8aefb08 100644 --- a/stride/core.py +++ b/stride/core.py @@ -294,7 +294,7 @@ def __init__(self, *args, **kwargs): self.prev_op = None self.needs_grad = kwargs.pop('needs_grad', False) - self._redux_grads = set() + self._redux_grads = dict() self._redux_task = False async def adjoint(self, grad=None, **kwargs): @@ -568,7 +568,7 @@ async def __call_adjoint__(self, grad, redux=False, **kwargs): """ if redux: - self._redux_grads.add(grad[0]) + self._redux_grads[grad[0].warehouse_id] = grad[0] if not self._redux_task: runtime = mosaic.runtime() @@ -607,7 +607,7 @@ async def __redux_adjoint__(self): inits = [] tasks = [] - for g in self._redux_grads: + for g in self._redux_grads.values(): redux_task = TaskProxy(redux_proxy, '__call_adjoint__', g) inits.append(redux_proxy._init_task(redux_task, g)) tasks.append(redux_task) @@ -616,10 +616,12 @@ async def __redux_adjoint__(self): await asyncio.gather(*tasks) drops = [] - for g in self._redux_grads: + for g in self._redux_grads.values(): drops.append(g.drop(propagate=True)) await asyncio.gather(*drops) + self._redux_grads = dict() + def __repr__(self): return self.name From 95f495a1adeb4f16d675d584cbc1ff12b28e5806 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 24 Dec 2025 10:48:28 +0000 Subject: [PATCH 4/4] Redux task flag reset --- stride/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stride/core.py b/stride/core.py index 8aefb08..392d736 100644 --- a/stride/core.py +++ b/stride/core.py @@ -293,7 +293,7 @@ def __init__(self, *args, **kwargs): self.graph = Graph() self.prev_op = None self.needs_grad = kwargs.pop('needs_grad', False) - + self._redux_grads = dict() self._redux_task = False @@ -621,6 +621,7 @@ async def __redux_adjoint__(self): await asyncio.gather(*drops) self._redux_grads = dict() + self._redux_task = False def __repr__(self): return self.name