diff --git a/mosaic/file_manipulation/h5.py b/mosaic/file_manipulation/h5.py index 54afddd..a52c19b 100644 --- a/mosaic/file_manipulation/h5.py +++ b/mosaic/file_manipulation/h5.py @@ -175,23 +175,23 @@ def read(obj, lazy=True, filter=None, only=None): def _read_dataset(obj, lazy=True): - if 'is_none' in obj.attrs and obj.attrs['is_none']: + if obj.attrs.get('is_none'): return None if obj.attrs['is_ndarray']: - def load(): - return obj[()] + if lazy is True: + def load(): + return obj[()] - setattr(obj, 'load', load) + setattr(obj, 'load', load) - if lazy is True: return obj else: return obj[()] - elif 'is_bytes' in obj.attrs and obj.attrs['is_bytes']: + elif obj.attrs.get('is_bytes'): obj = obj[()].tobytes() return obj diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 3c77132..bf291c8 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -5,6 +5,7 @@ import asyncio import datetime import subprocess as cmd_subprocess +from collections import defaultdict import mosaic from .runtime import Runtime, RuntimeProxy @@ -75,6 +76,9 @@ def __init__(self, **kwargs): self._monitored_tessera = dict() self._monitored_tasks = dict() + self._runtime_tessera = defaultdict(list) + self._runtime_tasks = defaultdict(list) + self._dirty_tessera = set() self._dirty_tasks = set() @@ -329,6 +333,7 @@ def add_task_profile(self, sender_id, msgs): def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) + self._runtime_tessera[runtime_id].append(uid) obj = self._monitored_tessera[uid] obj.add_event(sender_id, **kwargs) @@ -338,6 +343,7 @@ def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): if uid not in self._monitored_tasks: self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id) + self._runtime_tasks[runtime_id].append(uid) obj = self._monitored_tasks[uid] obj.add_event(sender_id, **kwargs) @@ -347,6 +353,7 @@ def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) + self._runtime_tessera[runtime_id].append(uid) obj = self._monitored_tessera[uid] obj.add_profile(sender_id, profile) @@ -355,6 +362,7 @@ def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile): if uid not in self._monitored_tasks: self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id) + self._runtime_tasks[runtime_id].append(uid) obj = self._monitored_tasks[uid] obj.add_profile(sender_id, profile) @@ -457,6 +465,43 @@ async def stop(self, sender_id=None): await super().stop(sender_id) + def disconnect(self, sender_id, uid): + """ + Disconnect specific remote runtime. + + Parameters + ---------- + sender_id : str + uid : str + + Returns + ------- + + """ + super().disconnect(sender_id, uid) + + # remove runtime from monitored nodes + try: + del self._monitored_nodes[uid] + except KeyError: + pass + + # remove monitored tessera + for obj_uid in self._runtime_tessera[uid]: + try: + del self._monitored_tessera[obj_uid] + except KeyError: + pass + del self._runtime_tessera[uid] + + # remove monitored tasks + for obj_uid in self._runtime_tasks[uid]: + try: + del self._monitored_tasks[obj_uid] + except KeyError: + pass + del self._runtime_tasks[uid] + async def select_worker(self, sender_id): """ Select appropriate worker to allocate a tessera. @@ -505,5 +550,13 @@ async def barrier(self, sender_id, timeout=None): if task.state in ['done', 'failed', 'collected']: pending_tasks.remove(task) + for task in self._monitored_tasks.values(): + if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks: + pending_tasks.append(task) + if timeout is not None and (time.time() - tic) > timeout: break + + self._monitored_tasks = dict() + self._runtime_tasks = defaultdict(list) + self._dirty_tasks = set() diff --git a/mosaic/runtime/worker.py b/mosaic/runtime/worker.py index b6c8512..8129571 100644 --- a/mosaic/runtime/worker.py +++ b/mosaic/runtime/worker.py @@ -32,6 +32,8 @@ def __init__(self, **kwargs): os.environ['OMP_NUM_THREADS'] = str(self._num_threads) os.environ['NUMBA_NUM_THREADS'] = str(self._num_threads) + os.environ['MKL_NUM_THREADS'] = str(self._num_threads) + os.environ['OPENBLAS_NUM_THREADS'] = str(self._num_threads) def set_logger(self): """ diff --git a/stride/__init__.py b/stride/__init__.py index 0847fe7..20a0351 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -232,9 +232,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa filter_traces = kwargs.pop('filter_traces', True) filter_wavelets = kwargs.pop('filter_wavelets', filter_traces) - fw3d_mode = kwargs.get('fw3d_mode', False) - filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', - 0.75 if not fw3d_mode else 0.725) + filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.75) filter_traces_relaxation = kwargs.pop('filter_traces_relaxation', 0.75 if filter_wavelets else 1.00) @@ -454,9 +452,9 @@ async def loop(worker, shot_id): fun.clear_graph() iteration.add_loss(fun) - iteration.add_completed(sub_problem.shot) logger.perf('Functional value for shot %d: %s' % (shot_id, fun)) + iteration.add_completed(sub_problem.shot) logger.perf('Retrieved test step for shot %d (%d out of %d)' % (sub_problem.shot_id, iteration.num_completed, num_shots)) diff --git a/stride/core.py b/stride/core.py index 536b7b5..e08de56 100644 --- a/stride/core.py +++ b/stride/core.py @@ -401,9 +401,11 @@ def _dealloc(*args): prev[nxt.name_idx] = input_grad + parallel_inits = [task.init_future for task in parallel_returns] + eager = not len(returns) or returns[-1]._eager if eager: - await asyncio.gather(*returns) + await asyncio.gather(*(returns + parallel_inits)) else: summ_returns = [] summ_dependencies = [] @@ -413,7 +415,7 @@ def _dealloc(*args): for runtime_deps in ret._dependencies.values(): summ_dependencies += list(runtime_deps.values()) - await asyncio.gather(*summ_returns) + await asyncio.gather(*(summ_returns + parallel_inits)) self.clear_graph() diff --git a/stride/optimisation/optimisers/gradient_descent.py b/stride/optimisation/optimisers/gradient_descent.py index 21a7dd8..bd9cb1b 100644 --- a/stride/optimisation/optimisers/gradient_descent.py +++ b/stride/optimisation/optimisers/gradient_descent.py @@ -29,7 +29,3 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs): processed_grad=processed_grad, **kwargs) return processed_grad - - def update_variable(self, step_size, variable, direction): - variable.data[:] -= step_size * direction.data - return variable diff --git a/stride/optimisation/optimisers/optimiser.py b/stride/optimisation/optimisers/optimiser.py index 770ac96..dc71c86 100644 --- a/stride/optimisation/optimisers/optimiser.py +++ b/stride/optimisation/optimisers/optimiser.py @@ -1,6 +1,6 @@ import numpy as np -from abc import ABC, abstractmethod +from abc import ABC import mosaic @@ -79,8 +79,8 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs): logger = mosaic.logger() logger.perf('Updating variable %s,' % self.variable.name) - problem = kwargs.pop('problem', None) - iteration = kwargs.pop('iteration', None) + problem = kwargs.get('problem', None) + iteration = kwargs.get('iteration', None) if processed_grad is None: if grad is None: @@ -229,7 +229,7 @@ async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs): variable = self.variable.transform(self.variable) else: variable = self.variable - upd_variable = self.update_variable(next_step, variable, direction) + upd_variable = self.update_variable(next_step, variable, direction, **kwargs) if self.variable.transform is not None: upd_variable = self.variable.transform(upd_variable) self.variable.data[:] = upd_variable.data.copy() @@ -278,8 +278,7 @@ async def post_process(self, **kwargs): self.variable.release_grad() - @abstractmethod - def update_variable(self, step_size, variable, direction): + def update_variable(self, step_size, variable, direction, **kwargs): """ Parameters @@ -297,7 +296,8 @@ def update_variable(self, step_size, variable, direction): Updated variable. """ - pass + variable.data[:] -= step_size * direction.data + return variable def reset(self, **kwargs): """ diff --git a/stride/optimisation/pipelines/pipeline.py b/stride/optimisation/pipelines/pipeline.py index e026fdc..29c3e0f 100644 --- a/stride/optimisation/pipelines/pipeline.py +++ b/stride/optimisation/pipelines/pipeline.py @@ -67,14 +67,24 @@ def __init__(self, steps=None, **kwargs): step, do_raise = step if isinstance(step, str): + step_name = step step_cls = steps_registry.get(step, None) if step_cls is None and do_raise: raise ValueError('Pipeline step %s does not exist in the registry' % step) if step_cls is not None: - self._steps[step] = step_cls(**kwargs) + step = step_cls(**kwargs) + else: + continue else: - self._steps[str(step)] = step + step_name = str(step) + + cnt = 0 + while step_name in self._steps: + step_name = '%s%d' % (step_name, cnt) + cnt += 1 + + self._steps[step_name] = step def insert(self, loc, key, step): pos = list(self._steps.keys()).index(loc) diff --git a/stride/optimisation/pipelines/steps/mask_field.py b/stride/optimisation/pipelines/steps/mask_field.py index 695ac96..0bc697d 100644 --- a/stride/optimisation/pipelines/steps/mask_field.py +++ b/stride/optimisation/pipelines/steps/mask_field.py @@ -51,7 +51,7 @@ def forward(self, field, **kwargs): mask = kwargs.pop('mask', None) mask_rampoff = kwargs.pop('mask_rampoff', self.mask_rampoff) mask = self._mask if mask is None else mask - if mask is None: + if mask is None or np.any([m != f for m, f in zip(mask.shape, field.extended_shape)]): mask = np.zeros(field.extended_shape, dtype=np.float32) mask[field.inner] = 1 mask *= _rampoff_mask(mask.shape, mask_rampoff) diff --git a/stride/optimisation/pipelines/steps/shift_traces.py b/stride/optimisation/pipelines/steps/shift_traces.py index 763175c..78e6733 100644 --- a/stride/optimisation/pipelines/steps/shift_traces.py +++ b/stride/optimisation/pipelines/steps/shift_traces.py @@ -1,4 +1,6 @@ +import numpy as np + from .utils import name_from_op_name from ....core import Operator @@ -56,7 +58,7 @@ def _apply(self, traces, **kwargs): return traces f_max_dim_less = 1/relaxation*f_max*time.step if f_max is not None else 0 - period = int(1 / f_max_dim_less) + period = int(np.round(1 / f_max_dim_less)) shift = period//4 out_data = traces.extended_data.copy() diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index dc5893d..9c1cab7 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -516,9 +516,13 @@ def undersampled_time_function(self, name, factor, time_bounds=None, Generated function. """ + layers = kwargs.pop('layers', devito.NoLayers) + time_bounds = time_bounds or (0, self.time.extended_num-1) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) - time_under, buffer_size = self._time_undersampled('time_under', factor, time_bounds) + time_under, buffer_size = self._time_undersampled('time_under', factor, time_bounds, + layers=layers) compression = kwargs.pop('compression', None) @@ -529,6 +533,7 @@ def undersampled_time_function(self, name, factor, time_bounds=None, save=buffer_size, dtype=kwargs.pop('dtype', self.dtype), compression=compression, + layers=layers, **kwargs) return fun @@ -544,15 +549,20 @@ def undersampled_time_derivative(self, fun, factor, time_bounds=None, offset=Non return deriv - def _time_undersampled(self, name, factor, time_bounds=None, offset=None): + def _time_undersampled(self, name, factor, time_bounds=None, offset=None, layers=devito.NoLayers): time_bounds = time_bounds or (0, self.time.extended_num - 1) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) offset = offset or (0, 0) time_dim = self.devito_grid.time_dim - condition = sympy.And(devito.CondEq(time_dim % factor, 0), - devito.Ge(time_dim, time_bounds[0] + offset[0]), - devito.Le(time_dim, time_bounds[1] - offset[1]), ) + # TODO Current incompatibility between disk spilling and conditional dimensions + if layers in [devito.DiskHost, devito.DiskHostDevice, devito.DiskDevice]: + condition = None + else: + condition = sympy.And(devito.CondEq(time_dim % factor, 0), + devito.Ge(time_dim, time_bounds[0] + offset[0]), + devito.Le(time_dim, time_bounds[1] - offset[1]), ) time_under = devito.ConditionalDimension('%s%d' % (name, self._time_under_count), parent=time_dim, @@ -560,7 +570,10 @@ def _time_undersampled(self, name, factor, time_bounds=None, offset=None): condition=condition) self._time_under_count += 1 - buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 + if layers in [devito.DiskHost, devito.DiskHostDevice, devito.DiskDevice]: + buffer_size = (self.time.extended_num - 1 + factor) // factor + 1 + else: + buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 return time_under, buffer_size @@ -598,6 +611,7 @@ def sparse_time_function(self, name, num=1, space_order=None, time_order=None, space_order = self.space_order if space_order is None else space_order time_order = self.time_order if time_order is None else time_order time_bounds = kwargs.pop('time_bounds', (0, self.time.extended_num)) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) smooth = kwargs.pop('smooth', False) # Define variables @@ -768,6 +782,26 @@ def delete(self, name, collect=False): if collect: devito.clear_cache(force=True) + def clear_cache(self, collect=False): + """ + Remove all internal references to devito functions. + + Parameters + ---------- + collect : bool, optional + Whether to garbage collect after deallocate, defaults to ``False``. + + Returns + ------- + + """ + self.vars = Struct() + self.cached_args = Struct() + self.cached_funcs = Struct() + + if collect: + devito.clear_cache(force=True) + def with_halo(self, data, value=None, time_dependent=False, is_vector=False, **kwargs): """ Pad ndarray with appropriate halo given the grid space order. @@ -938,6 +972,7 @@ def set_operator(self, op, **kwargs): 'name': self.name, 'subs': subs, 'opt': 'advanced', + 'autotuning': 'off', 'compiler': 'cuda', 'language': 'cuda', 'platform': 'nvidiaX', @@ -1002,7 +1037,7 @@ def run(self, **kwargs): if arg.name in self.grid.vars: default_kwargs[arg.name] = self.grid.vars[arg.name] - autotune = kwargs.pop('autotune', None) + autotune = kwargs.get('autotune', None) default_kwargs.update(kwargs) if self.grid.time_dim: diff --git a/stride/physics/common/import_devito.py b/stride/physics/common/import_devito.py index 7403c35..a70b36b 100644 --- a/stride/physics/common/import_devito.py +++ b/stride/physics/common/import_devito.py @@ -2,6 +2,7 @@ from devito import * # noqa: F401 from devito.types import Symbol, Scalar # noqa: F401 from devito.symbolics import INT, IntDiv, CondEq # noqa: F401 +from devito.tools import frozendict # noqa: F401 from devito import TimeFunction as TimeFunctionOSS # noqa: F401 try: diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index 8d80643..44fcfd0 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -186,6 +186,7 @@ def clear_operators(self): self.state_operator.devito_operator = None self.state_operator_save.devito_operator = None self.adjoint_operator.devito_operator = None + self.dev_grid.clear_cache() def deallocate_wavefield(self, platform='cpu', deallocate=False, **kwargs): if (platform and 'nvidia' in platform) \ @@ -360,10 +361,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): src_scale /= self.time.step src_term = src.inject(field=p.forward, expr=src * src_scale) - if not fw3d_mode: - rec_term = rec.interpolate(expr=p) - else: - rec_term = rec.interpolate(expr=p.forward) + rec_term = rec.interpolate(expr=p.forward) # Define the saving of the wavefield if save_wavefield is True: @@ -423,7 +421,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): if dump_forward_wavefield: if dump_wavefield_id == shot.id: factor = dump_forward_wavefield \ - if isinstance(dump_forward_wavefield, int) else self.undersampling_factor + if type(dump_forward_wavefield) is int else self.undersampling_factor layers = devito.Host if is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_dump', time_bounds=time_bounds, @@ -578,7 +576,7 @@ def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num)) op.run(dt=self.time.step, time_m=1, - time_M=time_bounds[1]-1, + time_M=time_bounds[1]-2, **functions, **devito_args) @@ -724,6 +722,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num)) fw3d_mode = kwargs.pop('fw3d_mode', False) + fw3d_mode = True # TODO Force fw3d_mode pending Devito-side fix platform = kwargs.get('platform', 'cpu') is_nvidia = platform is not None and 'nvidia' in platform @@ -749,7 +748,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k t = rec.time_dim vp2 = self.dev_grid.vars.vp**2 if not fw3d_mode: - rec_term = rec.inject(field=p_a.backward, expr=-rec.subs({t: t-1}) * self.time.step**2 * vp2) + rec_term = rec.inject(field=p_a, expr=-rec.subs({t: t-1}) * self.time.step**2 * vp2) else: rec_term = rec.inject(field=p_a.backward, expr=-rec * self.time.step**2 * vp2) @@ -766,7 +765,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k dump_wavefield_id = kwargs.pop('dump_wavefield_id', shot.id) if dump_adjoint_wavefield and dump_wavefield_id == shot.id: factor = dump_adjoint_wavefield \ - if isinstance(dump_adjoint_wavefield, int) else self.undersampling_factor + if type(dump_adjoint_wavefield) is int else self.undersampling_factor layers = devito.Host if is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_a_dump', time_bounds=time_bounds, @@ -788,8 +787,12 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k if self.attenuation_power == 2: kwargs['devito_config']['opt'] = 'noop' - self.adjoint_operator.set_operator(stencil + rec_term + src_term + gradient_update + update_saved, - **kwargs) + if fw3d_mode: + self.adjoint_operator.set_operator(stencil + rec_term + src_term + gradient_update + update_saved, + **kwargs) + else: + self.adjoint_operator.set_operator(rec_term + stencil + src_term + gradient_update + update_saved, + **kwargs) self.adjoint_operator.compile() else: diff --git a/stride/physics/problem_type.py b/stride/physics/problem_type.py index 92cd7bf..83f7432 100644 --- a/stride/physics/problem_type.py +++ b/stride/physics/problem_type.py @@ -74,13 +74,13 @@ def forward(self, *args, **kwargs): self.before_forward(*args, **kwargs) self.logger.perf('%sRunning state equation for shot' % pre_str) - self.run_forward(*args, **kwargs) + fw_output = self.run_forward(*args, **kwargs) self.logger.perf('%sCompleting state equation run for shot' % pre_str) output = self.after_forward(*args, **kwargs) self.logger.perf('%sCompleted state equation run for shot' % pre_str) - return output + return output if output is not None else fw_output def adjoint(self, *args, **kwargs): """ @@ -102,13 +102,13 @@ def adjoint(self, *args, **kwargs): self.before_adjoint(*args, **kwargs) self.logger.perf('%sRunning adjoint equation for shot' % pre_str) - self.run_adjoint(*args, **kwargs) + ad_output = self.run_adjoint(*args, **kwargs) self.logger.perf('%sCompleting adjoint equation run for shot' % pre_str) output = self.after_adjoint(*args, **kwargs) self.logger.perf('%sCompleted adjoint equation run for shot' % pre_str) - return output + return output if output is not None else ad_output def before_forward(self, *args, **kwargs): """ diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 95124a1..058d4de 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -1,13 +1,12 @@ import functools import numpy as np -from collections import OrderedDict from cached_property import cached_property import mosaic.types from mosaic.file_manipulation import h5 -from .data import Traces +from .data import Traces, DiskTraces from .base import ProblemBase from .. import plotting @@ -106,8 +105,8 @@ def __init__(self, id, name=None, problem=None, **kwargs): self._acquisitions = None self._sequence = None - self._sources = OrderedDict() - self._receivers = OrderedDict() + self._source_ids = [] + self._receiver_ids = [] self.wavelets = None self.observed = None @@ -118,10 +117,10 @@ def __init__(self, id, name=None, problem=None, **kwargs): if sources is not None and receivers is not None: for source in sources: - self._sources[source.id] = source + self._source_ids.append(source.id) for receiver in receivers: - self._receivers[receiver.id] = receiver + self._receiver_ids.append(receiver.id) self.wavelets = self._traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) @@ -144,7 +143,7 @@ def source_ids(self): Get ids of sources in this Shot in a list. """ - return list(self._sources.keys()) + return self._source_ids @property def receiver_ids(self): @@ -152,23 +151,39 @@ def receiver_ids(self): Get ids of receivers in this Shot in a list. """ - return list(self._receivers.keys()) + return self._receiver_ids - @property + @cached_property def sources(self): """ Get sources in this Shot as a list. """ - return list(self._sources.values()) + locations = [] + for loc_id in self.source_ids: + try: + loc = self._geometry.get(loc_id) + except AttributeError: + return None + locations.append(loc) - @property + return locations + + @cached_property def receivers(self): """ Get receivers in this Shot as a list. """ - return list(self._receivers.values()) + locations = [] + for loc_id in self.receiver_ids: + try: + loc = self._geometry.get(loc_id) + except AttributeError: + return None + locations.append(loc) + + return locations @property def slow_time_index(self): @@ -299,12 +314,10 @@ def sub_problem(self, shot, sub_problem): grid=self.grid, geometry=sub_problem.geometry) for source_id in self.source_ids: - location = sub_problem.geometry.get(source_id) - shot._sources[location.id] = location + shot._source_ids.append(source_id) for receiver_id in self.receiver_ids: - location = sub_problem.geometry.get(receiver_id) - shot._receivers[location.id] = location + shot._receiver_ids.append(receiver_id) if self.wavelets is not None: shot.wavelets = self.wavelets @@ -428,9 +441,12 @@ def append_observed(self, *args, **kwargs): else: self._acquisitions.dump(*args, shot_ids=[self.id], **kwargs) - @staticmethod - def _traces(*args, **kwargs): - return Traces(*args, **kwargs) + def _traces(self, *args, **kwargs): + if kwargs.pop('lazy_loading', False): + return DiskTraces(*args, **kwargs, + path='/shots/%d/%s' % (self.id, kwargs.get('name'))) + else: + return Traces(*args, **kwargs) def __get_desc__(self, **kwargs): description = { @@ -453,36 +469,49 @@ def __get_desc__(self, **kwargs): return description def __set_desc__(self, description, **kwargs): + try: + del self.__dict__['sources'] + except: + pass + try: + del self.__dict__['receivers'] + except: + pass + compressed = kwargs.pop('compressed', False) self.id = description.id for source_id in description.source_ids: - if self._geometry: - source = self._geometry.get(source_id) - else: - source = None - self._sources[source_id] = source + self._source_ids.append(source_id) for receiver_id in description.receiver_ids: - if self._geometry: - receiver = self._geometry.get(receiver_id) - else: - receiver = None - self._receivers[receiver_id] = receiver + self._receiver_ids.append(receiver_id) lazy_loading = kwargs.pop('lazy_loading', False) - self.wavelets = self._traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) - if 'wavelets' in description and not lazy_loading: + self.wavelets = self._traces( + name='wavelets', transducer_ids=self.source_ids, + grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) + if not lazy_loading and 'wavelets' in description: self.wavelets.__set_desc__(description.wavelets, **kwargs) - self.observed = self._traces(name='observed', transducer_ids=self.receiver_ids, compressed=compressed, grid=self.grid) - if 'observed' in description and not lazy_loading: + self.observed = self._traces( + name='observed', transducer_ids=self.receiver_ids, + compressed=compressed, grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) + if not lazy_loading and 'observed' in description: self.observed.__set_desc__(description.observed, **kwargs) - self.delays = self._traces(name='delays', transducer_ids=self.source_ids, shape=(len(self.source_ids), 1), grid=self.grid) - if 'delays' in description and not lazy_loading: + self.delays = self._traces( + name='delays', transducer_ids=self.source_ids, + shape=(len(self.source_ids), 1), grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) + if not lazy_loading and 'delays' in description: self.delays.__set_desc__(description.delays, **kwargs) @@ -532,7 +561,7 @@ def __init__(self, id, acq=0, name=None, problem=None, **kwargs): self._geometry = geometry self._acquisitions = acquisitions - self._shots = OrderedDict() + self._shots = dict() self._shot_selection = [] @property @@ -607,14 +636,13 @@ def get(self, frame): Found Shot. """ - if isinstance(frame, (np.int32, np.int64)): - frame = int(frame) - - if not isinstance(frame, int) or frame < 0: + try: + return self._shots[frame] + except KeyError: + if isinstance(frame, (np.int32, np.int64)): + return self.get(int(frame)) raise ValueError('Shot frames have to be positive integer numbers') - return self._shots[frame] - def set(self, frame, item): """ Change an existing shot in the Sequence. @@ -736,8 +764,8 @@ def __init__(self, name='acquisitions', problem=None, **kwargs): geometry = kwargs.pop('geometry', None) self._geometry = geometry - self._shots = OrderedDict() - self._sequences = OrderedDict() + self._shots = dict() + self._sequences = dict() self._shot_selection = [] self._sequence_selection = [] self._prev_load = None, None @@ -814,7 +842,7 @@ def remaining_shots(self): Get dict of all shots that have no observed allocated. """ - shots = OrderedDict() + shots = dict() for shot_id, shot in self._shots.items(): if not shot.observed.allocated: shots[shot_id] = shot @@ -887,14 +915,13 @@ def get(self, id): Found Shot. """ - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._shots[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Shot IDs have to be positive integer numbers') - return self._shots[id] - def get_sequence(self, id): """ Get a sequence from the Acquisitions with a known id. @@ -910,14 +937,13 @@ def get_sequence(self, id): Found Sequence. """ - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._sequences[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Sequence IDs have to be positive integer numbers') - return self._sequences[id] - def set(self, item): """ Change an existing shot in the Acquisitions. @@ -1338,7 +1364,7 @@ def __set_desc__(self, description, **kwargs): problem=self.problem, grid=self.grid) self.add(shot) - shot = self.get(shot_desc.id) + shot = self._shots[shot_desc.id] shot.__set_desc__(shot_desc, compressed=compressed, **kwargs) if 'sequences' in description: diff --git a/stride/problem/base.py b/stride/problem/base.py index 41e1b4f..cbbbc0e 100644 --- a/stride/problem/base.py +++ b/stride/problem/base.py @@ -72,7 +72,7 @@ def slow_time(self): def resample(self, grid=None, space=None, time=None, slow_time=None): raise NotImplementedError('Resampling has not been implemented in this class yet.' + - ' Alternatively, try to access space_resample or time_resample via problem.') + ' Alternatively, try to access space_resample or time_resample via problem.') class Saved: @@ -149,6 +149,7 @@ def load(self, *args, **kwargs): with h5.HDF5(*args, **kwargs, mode='r') as file: description = file.load(filter=kwargs.pop('filter', None), only=kwargs.pop('only', None)) + kwargs['filename'] = kwargs.pop('filename', file.filename) self.__set_desc__(description, **kwargs) def rm(self, *args, **kwargs): @@ -271,6 +272,7 @@ def load(self, *args, **kwargs): self._grid.slow_time = slow_time + kwargs['filename'] = kwargs.pop('filename', file.filename) self.__set_desc__(description, **kwargs) def grid_description(self): diff --git a/stride/problem/data.py b/stride/problem/data.py index 021c3e8..d875fd7 100644 --- a/stride/problem/data.py +++ b/stride/problem/data.py @@ -18,6 +18,7 @@ import mosaic from mosaic.core.tessera import PickleClass from mosaic.comms.compression import maybe_compress, decompress +from mosaic.file_manipulation import h5 from .base import GriddedSaved from ..core import Variable @@ -25,7 +26,7 @@ __all__ = ['Data', 'StructuredData', 'Scalar', 'ScalarField', 'VectorField', 'Traces', - 'SparseField', 'SparseCoordinates'] + 'DiskTraces', 'SparseField', 'SparseCoordinates'] def inv_transform(x): @@ -405,7 +406,7 @@ def process_grad(self, global_prec=True, **kwargs): self.grad.apply_prec(**kwargs) return self.grad - def apply_prec(self, prec_scale=4.0, prec_smooth=None, prec_op=None, prec=None, **kwargs): + def apply_prec(self, prec_scale=4.0, prec_smooth=1.0, prec_op=None, prec=None, **kwargs): """ Apply a pre-conditioner to the current field. @@ -1640,6 +1641,172 @@ def __set_desc__(self, description, **kwargs): self._transducer_ids = description.transducer_ids +@mosaic.tessera +class DiskTraces(Traces): + """ + Objects of this type describe a set of time traces defined over the time grid, + which are lazily loaded from disk. + + See ``Traces`` for definition of parameters. + + """ + + def __init__(self, **kwargs): + self._filename = kwargs.pop('filename', None) + self._path = kwargs.pop('path', None) + + data = kwargs.pop('data', None) + super().__init__(**kwargs) + + @property + def _data(self): + return None + + @_data.setter + def _data(self, value): + pass + + def load(self, **kwargs): + with h5.HDF5(filename=self._filename, **kwargs, mode='r') as file: + file = file.file + data = file['%s/data' % self.path][()] + + return Traces( + data=data, + transducer_ids=self.transducer_ids, + grid=self.grid, + ) + + def get(self, id): + """ + Get one trace based on a transducer ID, selecting the inner domain. + + Parameters + ---------- + id : int + Transducer ID. + + Returns + ------- + 1d-array + Time trace. + + """ + return self.load().get(id) + + def get_extended(self, id): + """ + Get one trace based on a transducer ID, selecting the extended domain. + + Parameters + ---------- + id : int + Transducer ID. + + Returns + ------- + 1d-array + Time trace. + + """ + return self.load().get_extended(id) + + def plot(self, **kwargs): + """ + Plot the inner domain of the traces as a shot gather. + + Parameters + ---------- + kwargs + Arguments for plotting. + + Returns + ------- + axes + Axes on which the plotting is done. + + """ + return self.load().plot(**kwargs) + + def plot_one(self, id, **kwargs): + """ + Plot the the inner domain of one of the traces. + + Parameters + ---------- + id : int + Transducer ID. + kwargs + Arguments for plotting. + + Returns + ------- + axes + Axes on which the plotting is done. + + """ + return self.load().plot_one(id, **kwargs) + + def _resample(self, old_step, new_step, new_num, **kwargs): + ''' + Resample the current trace to a new time-spacing. + + Parameters + ---------- + old_step: float + The original time step. + new_step: float + The new time step. + new_num + The length of the trace. + ''' + + # TODO Enable lazy resampling by marking Traces to be resampled on load + raise RuntimeError('DiskTraces cannot be resampled yet') + + def __get_desc__(self, **kwargs): + return self.load().__get_desc__(**kwargs) + + def __set_desc__(self, description, **kwargs): + del description.data + super().__set_desc__(description, **kwargs) + + _serialisation_attrs = ['name', 'uname', '_init_name', '_shape', '_extended_shape', '_inner', + '_dtype', 'needs_grad', '_compressed', '_compression', + 'transform', 'grad', 'prec', '_transducer_ids', + '_filename', '_path', '_grid'] + + def _serialisation_helper(self): + state = {} + + for attr in self._serialisation_attrs: + state[attr] = getattr(self, attr) + + return state + + @classmethod + def _deserialisation_helper(cls, state): + instance = Traces.__new__(Traces) + + with h5.HDF5(filename=state.pop('_filename'), mode='r') as file: + file = file.file + try: + data = file['%s/data' % state.pop('_path')][()] + except KeyError: + data = None + + instance._data = data + + for attr, value in state.items(): + setattr(instance, attr, value) + + return instance + + def __reduce__(self): + state = self._serialisation_helper() + return self._deserialisation_helper, (state,) + + @mosaic.tessera class SparseField(StructuredData): """ @@ -1733,7 +1900,7 @@ def alike(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().alike(*args, **kwargs) @@ -1751,7 +1918,7 @@ def detach(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().detach(*args, **kwargs) @@ -1769,7 +1936,7 @@ def as_parameter(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().as_parameter(*args, **kwargs) diff --git a/stride/problem/geometry.py b/stride/problem/geometry.py index dce199e..28b5cdb 100644 --- a/stride/problem/geometry.py +++ b/stride/problem/geometry.py @@ -1,5 +1,4 @@ import numpy as np -from collections import OrderedDict import mosaic.types from .base import GriddedSaved, ProblemBase @@ -136,7 +135,7 @@ def __init__(self, name='geometry', problem=None, **kwargs): else: transducers = kwargs.pop('transducers', None) - self._locations = OrderedDict() + self._locations = dict() self._transducers = transducers def add(self, id, transducer, coordinates, orientation=None): @@ -198,14 +197,13 @@ def get(self, id): Found TransducerLocation. """ - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._locations[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Transducer IDs have to be positive integer numbers') - return self._locations[id] - def get_slice(self, start=None, end=None, step=None): """ Get a slice of the indices of the locations using ``slice(start, stop, step)``. @@ -225,7 +223,7 @@ def get_slice(self, start=None, end=None, step=None): Found transducer locations in the slice. """ - section = OrderedDict() + section = dict() if start is None: _range = range(end) elif step is None: diff --git a/stride/problem/medium.py b/stride/problem/medium.py index 0ef919e..042a553 100644 --- a/stride/problem/medium.py +++ b/stride/problem/medium.py @@ -1,6 +1,4 @@ -from collections import OrderedDict - from .base import ProblemBase @@ -34,7 +32,7 @@ class Medium(ProblemBase): def __init__(self, name='medium', problem=None, **kwargs): super().__init__(name=name, problem=problem, **kwargs) - self._fields = OrderedDict() + self._fields = dict() def _get(self, item): if item in super().__getattribute__('_fields').keys(): diff --git a/stride/problem/problem.py b/stride/problem/problem.py index 399fef0..49015f8 100644 --- a/stride/problem/problem.py +++ b/stride/problem/problem.py @@ -148,7 +148,7 @@ def space_resample(self, new_spacing, new_extra=None, new_absorbing=None, **kwar self.medium.fields[field]._resample(old_spacing, new_spacing, slowness=True, **kwargs) else: self.medium.fields[field]._resample(old_spacing, new_spacing, **kwargs) - return [self.medium.fields[field] for field in self.medium.fields] + return tuple(self.medium.fields[field] for field in self.medium.fields) def time_resample(self, new_step, new_num=None, **kwargs): ''' diff --git a/stride/problem/transducers.py b/stride/problem/transducers.py index d52977d..938b3ed 100644 --- a/stride/problem/transducers.py +++ b/stride/problem/transducers.py @@ -1,6 +1,4 @@ -from collections import OrderedDict - from mosaic.utils import camel_case from .base import ProblemBase @@ -34,7 +32,7 @@ class Transducers(ProblemBase): def __init__(self, name='transducers', problem=None, **kwargs): super().__init__(name=name, problem=problem, **kwargs) - self._transducers = OrderedDict() + self._transducers = dict() def add(self, item): """ @@ -69,11 +67,11 @@ def get(self, id): Found Transducer. """ - if not isinstance(id, int) or id < 0: + try: + return self._transducers[id] + except KeyError: raise ValueError('Transducer IDs have to be positive integer numbers') - return self._transducers[id] - def get_slice(self, start=None, end=None, step=None): """ Get a slice of the indices of the transducer using ``slice(start, stop, step)``. @@ -93,7 +91,7 @@ def get_slice(self, start=None, end=None, step=None): Found transducers in the slice. """ - section = OrderedDict() + section = dict() if start is None: _range = range(end) elif step is None: diff --git a/stride/utils/filters.py b/stride/utils/filters.py index 7734d3c..2a9bb27 100644 --- a/stride/utils/filters.py +++ b/stride/utils/filters.py @@ -439,7 +439,7 @@ def lowpass_filter_cos(data, f_max, order=1, """ f_max = f_max / 0.5 - period = int(1 / f_max) + period = int(np.round(1 / f_max)) filter_length = 2*period + 1 table = _make_filter_cos(filter_length)