From 7c27079ca93626e31c60667721e3f3eeda817bde Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 3 Feb 2026 16:57:59 +0000 Subject: [PATCH 1/2] Fix eager run in adjoint --- stride/core.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/stride/core.py b/stride/core.py index 6b5b89f..2ddc446 100644 --- a/stride/core.py +++ b/stride/core.py @@ -342,6 +342,7 @@ def _dealloc(*args): returns = [] parallel_returns = [] deallocs = [] + self_uid = mosaic.runtime().uid for node in self.graph.toposort(self.prev_op): kwargs_ = kwargs.copy() @@ -356,27 +357,36 @@ def _dealloc(*args): # call adjoint method try: method = getattr(node.op, node.method) + is_proxy = isinstance(node.op, TesseraProxy) except AttributeError: method = getattr(node.op.obj, node.method) + is_proxy = isinstance(node.op.obj, TesseraProxy) + if hasattr(node.op, 'is_parameter') and node.op.is_parameter: redux_grad = await runtime.exec('redux-%s' % node.op.uid, redux, output_grads) ret = method((redux_grad,), **{**kwargs_, **{'eager': True, 'redux': True}}) else: - ret = method(*output_grads, **kwargs_) + if is_proxy: + ret = method(*output_grads, **{**kwargs_, **{'eager': True}}) + else: + ret = method(*output_grads, **kwargs_) if inspect.iscoroutine(ret) or inspect.iscoroutinefunction(ret): ret = await ret if isinstance(ret, TaskProxy): - if len(deallocs): - ret.add_done_callback(dealloc(deallocs)) - - if (not hasattr(node.op, 'has_tessera') or not node.op.has_tessera or not node.op.is_proxy) and \ - (not hasattr(node.op, 'is_parameter') or not node.op.is_parameter): - returns.append(ret) + if ret.runtime_id == self_uid: + ret = await ret else: - parallel_returns.append(ret) + if len(deallocs): + ret.add_done_callback(dealloc(deallocs)) + + if (not hasattr(node.op, 'has_tessera') or not node.op.has_tessera or not node.op.is_proxy) and \ + (not hasattr(node.op, 'is_parameter') or not node.op.is_parameter): + returns.append(ret) + else: + parallel_returns.append(ret) input_grads = ret.outputs else: From 7d8c688aacc0a2a37202fadba308dd523a805a5c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 3 Feb 2026 17:10:33 +0000 Subject: [PATCH 2/2] Handle None runtime --- stride/core.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stride/core.py b/stride/core.py index 2ddc446..7a09550 100644 --- a/stride/core.py +++ b/stride/core.py @@ -342,7 +342,10 @@ def _dealloc(*args): returns = [] parallel_returns = [] deallocs = [] - self_uid = mosaic.runtime().uid + try: + self_uid = mosaic.runtime().uid + except AttributeError: + self_uid = None for node in self.graph.toposort(self.prev_op): kwargs_ = kwargs.copy()