Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mosaic/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,9 @@ class TaskProxy(ProxyBase):
def __init__(self, proxy, method, *args, **kwargs):
super().__init__(*args, **kwargs)

cls = proxy._cls.cls if hasattr(proxy._cls, 'cls') else proxy._cls
self._uid = '%s-%s-%s-%s' % ('task',
proxy._cls.cls.__name__.lower(),
cls.__name__.lower(),
method,
uuid.uuid4().hex)
self._tessera_proxy = proxy
Expand Down
3 changes: 2 additions & 1 deletion mosaic/core/tessera.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,10 @@ def __init__(self, cls, *args, **kwargs):
runtime = kwargs.pop('runtime', None)
self._runtime_id = runtime.uid if hasattr(runtime, 'uid') else runtime

uid = kwargs.pop('uid', None)
self._uid = '%s-%s-%s' % ('tess',
self._cls.__name__.lower(),
uuid.uuid4().hex)
uuid.uuid4().hex) if uid is None else uid

self._cls_attr_names = None
self._set_cls()
Expand Down
13 changes: 9 additions & 4 deletions mosaic/runtime/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,21 +542,26 @@ async def barrier(self, sender_id, timeout=None):
pending_tasks.append(task)
self.logger.info('Pending barrier tasks %d' % len(pending_tasks))

num_tasks = len(self._monitored_tasks)

tic = time.time()
while pending_tasks:
await asyncio.sleep(0.5)
await asyncio.sleep(0.1)

for task in pending_tasks:
if task.state in ['done', 'failed', 'collected']:
pending_tasks.remove(task)

for task in self._monitored_tasks.values():
if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks:
pending_tasks.append(task)
if len(self._monitored_tasks) > num_tasks:
for task in self._monitored_tasks.values():
if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks:
pending_tasks.append(task)

if timeout is not None and (time.time() - tic) > timeout:
break

await self._local_warehouse.run_barrier_tasks(reply=True)

self._monitored_tasks = dict()
self._runtime_tasks = defaultdict(list)
self._dirty_tasks = set()
45 changes: 43 additions & 2 deletions mosaic/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(self, **kwargs):
self._dealloc_queue = []
self._maintenance_queue = []
self._maintenance_msgs = {}
self._barrier_tasks = []

self._inside_async_for = False

Expand Down Expand Up @@ -339,6 +340,25 @@ def cpu_load(self):

return 0

def register_barrier_task(self, task):
"""
Register task for execution at barrier.

Parameters
----------
task

Returns
-------

"""
self._barrier_tasks.append(task)

async def run_barrier_tasks(self, sender_id):
for task in self._barrier_tasks:
await task()
self._barrier_tasks = []

async def barrier(self, timeout=None):
"""
Wait until all pending tasks are done. If no timeout is
Expand Down Expand Up @@ -1080,21 +1100,42 @@ async def get(self, uid, cache=True):

return obj

async def drop(self, uid, cache_only=False):
async def exec(self, uid, func, func_args=None, func_kwargs=None):
"""
Retrieve an object from the warehouse.

Parameters
----------
uid
func
func_args
func_kwargs

Returns
-------

"""
ret = await self._local_warehouse.exec_remote(uid=uid, func=func,
func_args=func_args, func_kwargs=func_kwargs,
reply=True)
return ret

async def drop(self, uid, cache_only=False, propagate=False):
"""
Delete an object from the warehouse.

Parameters
----------
uid
cache_only
propagate

Returns
-------

"""
if not cache_only:
await self._local_warehouse.drop_remote(uid=uid)
await self._local_warehouse.drop_remote(uid=uid, propagate=propagate)

obj_uid = uid.uid if hasattr(uid, 'uid') else uid
try:
Expand Down
54 changes: 50 additions & 4 deletions mosaic/runtime/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,24 +167,70 @@ async def get_remote(self, sender_id, uid, warehouse_id=None, node_id=None):

return obj

async def drop_remote(self, sender_id, uid):
async def exec_remote(self, sender_id, uid, func, func_args=None, func_kwargs=None):
"""
Retrieve an object from the warehouse and execute function on it.

Parameters
----------
sender_id
uid
func
func_args
func_kwargs

Returns
-------

"""
if isinstance(uid, WarehouseObject):
obj_id = uid.uid
else:
obj_id = uid

try:
obj = self._local_warehouse[obj_id]
except KeyError:
obj = None

func_args = func_args or ()
func_kwargs = func_kwargs or {}
obj = await func(obj, *func_args, **func_kwargs)

self._local_warehouse[obj_id] = obj

warehouse_obj = WarehouseObject(obj, uid=obj_id)
return warehouse_obj

async def drop_remote(self, sender_id, uid, propagate=False):
"""
Delete an object from the warehouse.

Parameters
----------
sender_id
uid
propagate

Returns
-------

"""
if isinstance(uid, WarehouseObject):
uid = uid.uid
obj_id = uid.uid

if uid in self._local_warehouse:
del self._local_warehouse[uid]
if propagate:
node_id = uid.node_id
warehouse_id = uid.warehouse_id

if node_id is not None and warehouse_id is not None:
if node_id not in self._warehouses:
self._warehouses[node_id] = self.proxy(uid=warehouse_id)

await self._warehouses[node_id].drop_remote(uid=uid)

if obj_id in self._local_warehouse:
del self._local_warehouse[obj_id]

async def push_remote(self, sender_id, __dict__,
uid=None, warehouse_id=None, node_id=None,
Expand Down
82 changes: 66 additions & 16 deletions stride/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import uuid
import asyncio
import inspect
import numpy as np
from abc import abstractmethod
from collections import OrderedDict

import mosaic
from mosaic import types
from mosaic.core.base import CMDBase
from mosaic.core import TaskProxy
from mosaic.core import TesseraProxy, TaskProxy


__all__ = ['Variable', 'Operator']
Expand All @@ -35,7 +34,6 @@ async def _maybe_sum(a, b):
return b[0] + a, b[1]
elif isinstance(a, tuple) and isinstance(b, tuple):
return a[0] + b[0], a[1] + b[1]

return a + b


Expand Down Expand Up @@ -296,6 +294,9 @@ def __init__(self, *args, **kwargs):
self.prev_op = None
self.needs_grad = kwargs.pop('needs_grad', False)

self._redux_grads = dict()
self._redux_task = False

async def adjoint(self, grad=None, **kwargs):
"""
Run the adjoint graph that has this variable as its root.
Expand All @@ -322,6 +323,17 @@ async def adjoint(self, grad=None, **kwargs):

runtime = mosaic.runtime()

async def redux(rec_grads, *grads):
if rec_grads is None:
sums = [
_maybe_sum(None, g) for g in grads
]
else:
sums = [
_maybe_sum(r, g) for r, g in zip(rec_grads, grads)
]
return await asyncio.gather(*sums)

def dealloc(objs):
def _dealloc(*args):
loop = mosaic.get_event_loop()
Expand Down Expand Up @@ -351,7 +363,9 @@ def _dealloc(*args):
except AttributeError:
method = getattr(node.op.obj, node.method)
if hasattr(node.op, 'is_parameter') and node.op.is_parameter:
ret = method(*output_grads, **{**kwargs_, **{'eager': True}})
redux_grad = await runtime.exec('redux-%s' % node.op.uid, redux, output_grads)
ret = method((redux_grad,), **{**kwargs_, **{'eager': True, 'redux': True}})

else:
ret = method(*output_grads, **kwargs_)

Expand Down Expand Up @@ -539,7 +553,7 @@ def process_grad(self):
"""
raise NotImplementedError('Unimplemented Variable method process_grad')

async def __call_adjoint__(self, grad, **kwargs):
async def __call_adjoint__(self, grad, redux=False, **kwargs):
"""
Adjoint operation of the variable, which accumulates the given
gradient on the ``Variable.grad`` attribute.
Expand All @@ -553,25 +567,61 @@ async def __call_adjoint__(self, grad, **kwargs):
-------

"""
if redux:
self._redux_grads[grad[0].warehouse_id] = grad[0]

if not self._redux_task:
runtime = mosaic.runtime()
runtime.register_barrier_task(self.__redux_adjoint__)
self._redux_task = True

return

if grad is None or not self.needs_grad or self.grad is None:
return

grad_data = grad.data if hasattr(grad, 'data') else grad
is_nan = np.any(np.isnan(grad_data))
is_inf = np.any(np.isinf(grad_data))
if isinstance(grad, (list, tuple)):
grad = grad[0]

self.grad += grad

async def __redux_adjoint__(self):
"""
Reduction adjoint operation of the variable, which accumulates the given
gradient on the ``Variable.grad`` attribute.

if is_nan or is_inf:
msg = 'Nan or inf detected in %s' % self.name
Parameters
----------

problem = kwargs.pop('problem', None)
shot_id = problem.shot.id if problem is not None else kwargs.pop('shot_id', None)
if shot_id is not None:
msg = '(ShotID %d) ' % shot_id + msg
Returns
-------

mosaic.logger().warn(msg)
"""
if not hasattr(self, '_tessera'):
return

self.grad += grad
tess = self._tessera
redux_proxy = TesseraProxy(tess._cls, runtime=tess.runtime_id, uid=tess.uid)
redux_proxy.init_future.set_result(True)
redux_proxy.state_changed('listening')

inits = []
tasks = []
for g in self._redux_grads.values():
redux_task = TaskProxy(redux_proxy, '__call_adjoint__', g)
inits.append(redux_proxy._init_task(redux_task, g))
tasks.append(redux_task)

await asyncio.gather(*inits)
await asyncio.gather(*tasks)

drops = []
for g in self._redux_grads.values():
drops.append(g.drop(propagate=True))
await asyncio.gather(*drops)

self._redux_grads = dict()
self._redux_task = False

def __repr__(self):
return self.name
Expand Down
20 changes: 19 additions & 1 deletion stride/physics/problem_type.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC
import numpy as np

import mosaic

Expand Down Expand Up @@ -292,6 +293,23 @@ def get_grad(self, *wrt, **kwargs):
if method is None:
raise ValueError('Variable %s not implemented' % variable.name)

grads.append(method(variable, wrt=wrt, **kwargs))
grad = method(variable, wrt=wrt, **kwargs)

grad_data = grad.data if hasattr(grad, 'data') else grad
is_nan = np.any(np.isnan(grad_data))
is_inf = np.any(np.isinf(grad_data))

if is_nan or is_inf:
msg = 'Nan or inf detected in %s' % self.name

problem = kwargs.pop('problem', None)
shot_id = problem.shot.id if problem is not None else kwargs.pop('shot_id', None)
if shot_id is not None:
msg = '(ShotID %d) ' % shot_id + msg

self.logger.warn(msg)
return

grads.append(grad)

return tuple(grads)