Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mosaic/file_manipulation/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,23 +175,23 @@ def read(obj, lazy=True, filter=None, only=None):


def _read_dataset(obj, lazy=True):
if 'is_none' in obj.attrs and obj.attrs['is_none']:
if obj.attrs.get('is_none'):
return None

if obj.attrs['is_ndarray']:

def load():
return obj[()]
if lazy is True:
def load():
return obj[()]

setattr(obj, 'load', load)
setattr(obj, 'load', load)

if lazy is True:
return obj

else:
return obj[()]

elif 'is_bytes' in obj.attrs and obj.attrs['is_bytes']:
elif obj.attrs.get('is_bytes'):
obj = obj[()].tobytes()

return obj
Expand Down
53 changes: 53 additions & 0 deletions mosaic/runtime/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import datetime
import subprocess as cmd_subprocess
from collections import defaultdict

import mosaic
from .runtime import Runtime, RuntimeProxy
Expand Down Expand Up @@ -75,6 +76,9 @@ def __init__(self, **kwargs):
self._monitored_tessera = dict()
self._monitored_tasks = dict()

self._runtime_tessera = defaultdict(list)
self._runtime_tasks = defaultdict(list)

self._dirty_tessera = set()
self._dirty_tasks = set()

Expand Down Expand Up @@ -329,6 +333,7 @@ def add_task_profile(self, sender_id, msgs):
def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs):
if uid not in self._monitored_tessera:
self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid)
self._runtime_tessera[runtime_id].append(uid)

obj = self._monitored_tessera[uid]
obj.add_event(sender_id, **kwargs)
Expand All @@ -338,6 +343,7 @@ def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs):
def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs):
if uid not in self._monitored_tasks:
self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id)
self._runtime_tasks[runtime_id].append(uid)

obj = self._monitored_tasks[uid]
obj.add_event(sender_id, **kwargs)
Expand All @@ -347,6 +353,7 @@ def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs):
def _add_tessera_profile(self, sender_id, runtime_id, uid, profile):
if uid not in self._monitored_tessera:
self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid)
self._runtime_tessera[runtime_id].append(uid)

obj = self._monitored_tessera[uid]
obj.add_profile(sender_id, profile)
Expand All @@ -355,6 +362,7 @@ def _add_tessera_profile(self, sender_id, runtime_id, uid, profile):
def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile):
if uid not in self._monitored_tasks:
self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id)
self._runtime_tasks[runtime_id].append(uid)

obj = self._monitored_tasks[uid]
obj.add_profile(sender_id, profile)
Expand Down Expand Up @@ -457,6 +465,43 @@ async def stop(self, sender_id=None):

await super().stop(sender_id)

def disconnect(self, sender_id, uid):
"""
Disconnect specific remote runtime.

Parameters
----------
sender_id : str
uid : str

Returns
-------

"""
super().disconnect(sender_id, uid)

# remove runtime from monitored nodes
try:
del self._monitored_nodes[uid]
except KeyError:
pass

# remove monitored tessera
for obj_uid in self._runtime_tessera[uid]:
try:
del self._monitored_tessera[obj_uid]
except KeyError:
pass
del self._runtime_tessera[uid]

# remove monitored tasks
for obj_uid in self._runtime_tasks[uid]:
try:
del self._monitored_tasks[obj_uid]
except KeyError:
pass
del self._runtime_tasks[uid]

async def select_worker(self, sender_id):
"""
Select appropriate worker to allocate a tessera.
Expand Down Expand Up @@ -505,5 +550,13 @@ async def barrier(self, sender_id, timeout=None):
if task.state in ['done', 'failed', 'collected']:
pending_tasks.remove(task)

for task in self._monitored_tasks.values():
if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks:
pending_tasks.append(task)

if timeout is not None and (time.time() - tic) > timeout:
break

self._monitored_tasks = dict()
self._runtime_tasks = defaultdict(list)
self._dirty_tasks = set()
2 changes: 2 additions & 0 deletions mosaic/runtime/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(self, **kwargs):

os.environ['OMP_NUM_THREADS'] = str(self._num_threads)
os.environ['NUMBA_NUM_THREADS'] = str(self._num_threads)
os.environ['MKL_NUM_THREADS'] = str(self._num_threads)
os.environ['OPENBLAS_NUM_THREADS'] = str(self._num_threads)

def set_logger(self):
"""
Expand Down
6 changes: 2 additions & 4 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
filter_traces = kwargs.pop('filter_traces', True)
filter_wavelets = kwargs.pop('filter_wavelets', filter_traces)

fw3d_mode = kwargs.get('fw3d_mode', False)
filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation',
0.75 if not fw3d_mode else 0.725)
filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.75)
filter_traces_relaxation = kwargs.pop('filter_traces_relaxation',
0.75 if filter_wavelets else 1.00)

Expand Down Expand Up @@ -454,9 +452,9 @@ async def loop(worker, shot_id):
fun.clear_graph()

iteration.add_loss(fun)
iteration.add_completed(sub_problem.shot)
logger.perf('Functional value for shot %d: %s' % (shot_id, fun))

iteration.add_completed(sub_problem.shot)
logger.perf('Retrieved test step for shot %d (%d out of %d)'
% (sub_problem.shot_id,
iteration.num_completed, num_shots))
Expand Down
6 changes: 4 additions & 2 deletions stride/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,11 @@ def _dealloc(*args):

prev[nxt.name_idx] = input_grad

parallel_inits = [task.init_future for task in parallel_returns]

eager = not len(returns) or returns[-1]._eager
if eager:
await asyncio.gather(*returns)
await asyncio.gather(*(returns + parallel_inits))
else:
summ_returns = []
summ_dependencies = []
Expand All @@ -413,7 +415,7 @@ def _dealloc(*args):
for runtime_deps in ret._dependencies.values():
summ_dependencies += list(runtime_deps.values())

await asyncio.gather(*summ_returns)
await asyncio.gather(*(summ_returns + parallel_inits))

self.clear_graph()

Expand Down
4 changes: 0 additions & 4 deletions stride/optimisation/optimisers/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,3 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs):
processed_grad=processed_grad,
**kwargs)
return processed_grad

def update_variable(self, step_size, variable, direction):
variable.data[:] -= step_size * direction.data
return variable
14 changes: 7 additions & 7 deletions stride/optimisation/optimisers/optimiser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import numpy as np
from abc import ABC, abstractmethod
from abc import ABC

import mosaic

Expand Down Expand Up @@ -79,8 +79,8 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs):
logger = mosaic.logger()
logger.perf('Updating variable %s,' % self.variable.name)

problem = kwargs.pop('problem', None)
iteration = kwargs.pop('iteration', None)
problem = kwargs.get('problem', None)
iteration = kwargs.get('iteration', None)

if processed_grad is None:
if grad is None:
Expand Down Expand Up @@ -229,7 +229,7 @@ async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs):
variable = self.variable.transform(self.variable)
else:
variable = self.variable
upd_variable = self.update_variable(next_step, variable, direction)
upd_variable = self.update_variable(next_step, variable, direction, **kwargs)
if self.variable.transform is not None:
upd_variable = self.variable.transform(upd_variable)
self.variable.data[:] = upd_variable.data.copy()
Expand Down Expand Up @@ -278,8 +278,7 @@ async def post_process(self, **kwargs):

self.variable.release_grad()

@abstractmethod
def update_variable(self, step_size, variable, direction):
def update_variable(self, step_size, variable, direction, **kwargs):
"""

Parameters
Expand All @@ -297,7 +296,8 @@ def update_variable(self, step_size, variable, direction):
Updated variable.

"""
pass
variable.data[:] -= step_size * direction.data
return variable

def reset(self, **kwargs):
"""
Expand Down
14 changes: 12 additions & 2 deletions stride/optimisation/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,24 @@ def __init__(self, steps=None, **kwargs):
step, do_raise = step

if isinstance(step, str):
step_name = step
step_cls = steps_registry.get(step, None)
if step_cls is None and do_raise:
raise ValueError('Pipeline step %s does not exist in the registry' % step)

if step_cls is not None:
self._steps[step] = step_cls(**kwargs)
step = step_cls(**kwargs)
else:
continue
else:
self._steps[str(step)] = step
step_name = str(step)

cnt = 0
while step_name in self._steps:
step_name = '%s%d' % (step_name, cnt)
cnt += 1

self._steps[step_name] = step

def insert(self, loc, key, step):
pos = list(self._steps.keys()).index(loc)
Expand Down
2 changes: 1 addition & 1 deletion stride/optimisation/pipelines/steps/mask_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def forward(self, field, **kwargs):
mask = kwargs.pop('mask', None)
mask_rampoff = kwargs.pop('mask_rampoff', self.mask_rampoff)
mask = self._mask if mask is None else mask
if mask is None:
if mask is None or np.any([m != f for m, f in zip(mask.shape, field.extended_shape)]):
mask = np.zeros(field.extended_shape, dtype=np.float32)
mask[field.inner] = 1
mask *= _rampoff_mask(mask.shape, mask_rampoff)
Expand Down
4 changes: 3 additions & 1 deletion stride/optimisation/pipelines/steps/shift_traces.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@

import numpy as np

from .utils import name_from_op_name
from ....core import Operator

Expand Down Expand Up @@ -56,7 +58,7 @@ def _apply(self, traces, **kwargs):
return traces

f_max_dim_less = 1/relaxation*f_max*time.step if f_max is not None else 0
period = int(1 / f_max_dim_less)
period = int(np.round(1 / f_max_dim_less))
shift = period//4

out_data = traces.extended_data.copy()
Expand Down
Loading