From 4a5aff958b4579954cf3eeedaf81b2fca10a1f5c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 29 Jul 2025 10:38:01 +0100 Subject: [PATCH 01/23] Change rounding of filter length --- mosaic/runtime/worker.py | 2 ++ stride/__init__.py | 3 +-- stride/optimisation/pipelines/steps/shift_traces.py | 2 +- stride/physics/common/devito.py | 4 ++++ stride/physics/problem_type.py | 8 ++++---- stride/problem/data.py | 6 +++--- stride/utils/filters.py | 2 +- 7 files changed, 16 insertions(+), 11 deletions(-) diff --git a/mosaic/runtime/worker.py b/mosaic/runtime/worker.py index b6c85128..8129571a 100644 --- a/mosaic/runtime/worker.py +++ b/mosaic/runtime/worker.py @@ -32,6 +32,8 @@ def __init__(self, **kwargs): os.environ['OMP_NUM_THREADS'] = str(self._num_threads) os.environ['NUMBA_NUM_THREADS'] = str(self._num_threads) + os.environ['MKL_NUM_THREADS'] = str(self._num_threads) + os.environ['OPENBLAS_NUM_THREADS'] = str(self._num_threads) def set_logger(self): """ diff --git a/stride/__init__.py b/stride/__init__.py index 0847fe76..50bb2d3f 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -233,8 +233,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa filter_wavelets = kwargs.pop('filter_wavelets', filter_traces) fw3d_mode = kwargs.get('fw3d_mode', False) - filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', - 0.75 if not fw3d_mode else 0.725) + filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.75) filter_traces_relaxation = kwargs.pop('filter_traces_relaxation', 0.75 if filter_wavelets else 1.00) diff --git a/stride/optimisation/pipelines/steps/shift_traces.py b/stride/optimisation/pipelines/steps/shift_traces.py index 763175c4..fd6a0a03 100644 --- a/stride/optimisation/pipelines/steps/shift_traces.py +++ b/stride/optimisation/pipelines/steps/shift_traces.py @@ -56,7 +56,7 @@ def _apply(self, traces, **kwargs): return traces f_max_dim_less = 1/relaxation*f_max*time.step if f_max is not 
None else 0 - period = int(1 / f_max_dim_less) + period = int(np.round(1 / f_max_dim_less)) shift = period//4 out_data = traces.extended_data.copy() diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index dc5893df..1b4af226 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -517,6 +517,7 @@ def undersampled_time_function(self, name, factor, time_bounds=None, """ time_bounds = time_bounds or (0, self.time.extended_num-1) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) time_under, buffer_size = self._time_undersampled('time_under', factor, time_bounds) @@ -546,6 +547,7 @@ def undersampled_time_derivative(self, fun, factor, time_bounds=None, offset=Non def _time_undersampled(self, name, factor, time_bounds=None, offset=None): time_bounds = time_bounds or (0, self.time.extended_num - 1) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) offset = offset or (0, 0) time_dim = self.devito_grid.time_dim @@ -598,6 +600,7 @@ def sparse_time_function(self, name, num=1, space_order=None, time_order=None, space_order = self.space_order if space_order is None else space_order time_order = self.time_order if time_order is None else time_order time_bounds = kwargs.pop('time_bounds', (0, self.time.extended_num)) + time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) smooth = kwargs.pop('smooth', False) # Define variables @@ -938,6 +941,7 @@ def set_operator(self, op, **kwargs): 'name': self.name, 'subs': subs, 'opt': 'advanced', + 'autotuning': 'off', 'compiler': 'cuda', 'language': 'cuda', 'platform': 'nvidiaX', diff --git a/stride/physics/problem_type.py b/stride/physics/problem_type.py index 92cd7bfe..83f74321 100644 --- a/stride/physics/problem_type.py +++ b/stride/physics/problem_type.py @@ -74,13 +74,13 @@ def forward(self, *args, **kwargs): self.before_forward(*args, **kwargs) self.logger.perf('%sRunning state 
equation for shot' % pre_str) - self.run_forward(*args, **kwargs) + fw_output = self.run_forward(*args, **kwargs) self.logger.perf('%sCompleting state equation run for shot' % pre_str) output = self.after_forward(*args, **kwargs) self.logger.perf('%sCompleted state equation run for shot' % pre_str) - return output + return output if output is not None else fw_output def adjoint(self, *args, **kwargs): """ @@ -102,13 +102,13 @@ def adjoint(self, *args, **kwargs): self.before_adjoint(*args, **kwargs) self.logger.perf('%sRunning adjoint equation for shot' % pre_str) - self.run_adjoint(*args, **kwargs) + ad_output = self.run_adjoint(*args, **kwargs) self.logger.perf('%sCompleting adjoint equation run for shot' % pre_str) output = self.after_adjoint(*args, **kwargs) self.logger.perf('%sCompleted adjoint equation run for shot' % pre_str) - return output + return output if output is not None else ad_output def before_forward(self, *args, **kwargs): """ diff --git a/stride/problem/data.py b/stride/problem/data.py index d62ccc1a..916a1185 100644 --- a/stride/problem/data.py +++ b/stride/problem/data.py @@ -1769,7 +1769,7 @@ def alike(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().alike(*args, **kwargs) @@ -1787,7 +1787,7 @@ def detach(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().detach(*args, **kwargs) @@ 
-1805,7 +1805,7 @@ def as_parameter(self, *args, **kwargs): kwargs['num'] = kwargs.pop('num', self.num) kwargs['dim'] = kwargs.pop('dim', self.dim) kwargs['time_dependent'] = kwargs.pop('time_dependent', self.time_dependent) - kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.time_dependent) + kwargs['slow_time_dependent'] = kwargs.pop('slow_time_dependent', self.slow_time_dependent) return super().as_parameter(*args, **kwargs) diff --git a/stride/utils/filters.py b/stride/utils/filters.py index 16646fb5..eb0daf71 100644 --- a/stride/utils/filters.py +++ b/stride/utils/filters.py @@ -388,7 +388,7 @@ def lowpass_filter_cos(data, f_max, order=1, """ f_max = f_max / 0.5 - period = int(1 / f_max) + period = int(np.round(1 / f_max)) filter_length = 2*period + 1 table = _make_filter_cos(filter_length) From 73739a68ee939c0a5e069b164a64292dc752aa21 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 30 Jul 2025 11:53:05 +0100 Subject: [PATCH 02/23] Update trace sampling point --- .../optimisation/pipelines/steps/shift_traces.py | 2 ++ stride/physics/iso_acoustic/devito.py | 15 ++++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/stride/optimisation/pipelines/steps/shift_traces.py b/stride/optimisation/pipelines/steps/shift_traces.py index fd6a0a03..78e6733c 100644 --- a/stride/optimisation/pipelines/steps/shift_traces.py +++ b/stride/optimisation/pipelines/steps/shift_traces.py @@ -1,4 +1,6 @@ +import numpy as np + from .utils import name_from_op_name from ....core import Operator diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index 8d806433..e2464ce8 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -360,10 +360,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): src_scale /= self.time.step src_term = src.inject(field=p.forward, expr=src * src_scale) - if not fw3d_mode: - rec_term = rec.interpolate(expr=p) - 
else: - rec_term = rec.interpolate(expr=p.forward) + rec_term = rec.interpolate(expr=p.forward) # Define the saving of the wavefield if save_wavefield is True: @@ -749,7 +746,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k t = rec.time_dim vp2 = self.dev_grid.vars.vp**2 if not fw3d_mode: - rec_term = rec.inject(field=p_a.backward, expr=-rec.subs({t: t-1}) * self.time.step**2 * vp2) + rec_term = rec.inject(field=p_a, expr=-rec.subs({t: t-1}) * self.time.step**2 * vp2) else: rec_term = rec.inject(field=p_a.backward, expr=-rec * self.time.step**2 * vp2) @@ -788,8 +785,12 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k if self.attenuation_power == 2: kwargs['devito_config']['opt'] = 'noop' - self.adjoint_operator.set_operator(stencil + rec_term + src_term + gradient_update + update_saved, - **kwargs) + if not fw3d_mode: + self.adjoint_operator.set_operator(stencil + rec_term + src_term + gradient_update + update_saved, + **kwargs) + else: + self.adjoint_operator.set_operator(rec_term + stencil + src_term + gradient_update + update_saved, + **kwargs) self.adjoint_operator.compile() else: From 96a62a1134aa160e3b387df040fd5c27370ed66c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 4 Aug 2025 10:13:08 +0100 Subject: [PATCH 03/23] Minor fix to adjoint injection --- .../pipelines/steps/mask_field.py | 2 +- stride/physics/common/devito.py | 20 +++++++++++++++++++ stride/physics/iso_acoustic/devito.py | 3 ++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/stride/optimisation/pipelines/steps/mask_field.py b/stride/optimisation/pipelines/steps/mask_field.py index 695ac966..0bc697dd 100644 --- a/stride/optimisation/pipelines/steps/mask_field.py +++ b/stride/optimisation/pipelines/steps/mask_field.py @@ -51,7 +51,7 @@ def forward(self, field, **kwargs): mask = kwargs.pop('mask', None) mask_rampoff = kwargs.pop('mask_rampoff', self.mask_rampoff) mask = self._mask if mask is None else 
mask - if mask is None: + if mask is None or np.any([m != f for m, f in zip(mask.shape, field.extended_shape)]): mask = np.zeros(field.extended_shape, dtype=np.float32) mask[field.inner] = 1 mask *= _rampoff_mask(mask.shape, mask_rampoff) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index 1b4af226..43fd97f6 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -771,6 +771,26 @@ def delete(self, name, collect=False): if collect: devito.clear_cache(force=True) + def clear_cache(self, collect=False): + """ + Remove all internal references to devito functions. + + Parameters + ---------- + collect : bool, optional + Whether to garbage collect after deallocate, defaults to ``False``. + + Returns + ------- + + """ + self.vars = Struct() + self.cached_args = Struct() + self.cached_funcs = Struct() + + if collect: + devito.clear_cache(force=True) + def with_halo(self, data, value=None, time_dependent=False, is_vector=False, **kwargs): """ Pad ndarray with appropriate halo given the grid space order. 
diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index e2464ce8..e499da0c 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -186,6 +186,7 @@ def clear_operators(self): self.state_operator.devito_operator = None self.state_operator_save.devito_operator = None self.adjoint_operator.devito_operator = None + self.dev_grid.clear_cache() def deallocate_wavefield(self, platform='cpu', deallocate=False, **kwargs): if (platform and 'nvidia' in platform) \ @@ -785,7 +786,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k if self.attenuation_power == 2: kwargs['devito_config']['opt'] = 'noop' - if not fw3d_mode: + if fw3d_mode: self.adjoint_operator.set_operator(stencil + rec_term + src_term + gradient_update + update_saved, **kwargs) else: From 71d9fe447cdf004b9140c08eb1067f142f46d041 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 7 Aug 2025 14:10:50 +0100 Subject: [PATCH 04/23] Enable same-named steps in pipeline --- stride/optimisation/pipelines/pipeline.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/stride/optimisation/pipelines/pipeline.py b/stride/optimisation/pipelines/pipeline.py index e026fdc5..29c3e0f6 100644 --- a/stride/optimisation/pipelines/pipeline.py +++ b/stride/optimisation/pipelines/pipeline.py @@ -67,14 +67,24 @@ def __init__(self, steps=None, **kwargs): step, do_raise = step if isinstance(step, str): + step_name = step step_cls = steps_registry.get(step, None) if step_cls is None and do_raise: raise ValueError('Pipeline step %s does not exist in the registry' % step) if step_cls is not None: - self._steps[step] = step_cls(**kwargs) + step = step_cls(**kwargs) + else: + continue else: - self._steps[str(step)] = step + step_name = str(step) + + cnt = 0 + while step_name in self._steps: + step_name = '%s%d' % (step_name, cnt) + cnt += 1 + + self._steps[step_name] = step def insert(self, 
loc, key, step): pos = list(self._steps.keys()).index(loc) From 75b516c2cfcc8fcd11a825ad39448bf40f515992 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 18 Aug 2025 13:14:20 +0100 Subject: [PATCH 05/23] Ensure monitor waits for warehouse tasks --- mosaic/runtime/monitor.py | 4 ++++ stride/core.py | 6 ++++-- stride/problem/data.py | 2 +- stride/problem/problem.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 3c771325..30f25909 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -505,5 +505,9 @@ async def barrier(self, sender_id, timeout=None): if task.state in ['done', 'failed', 'collected']: pending_tasks.remove(task) + for task in self._monitored_tasks.values(): + if task.state not in ['done', 'failed', 'collected'] and task not in pending_tasks: + pending_tasks.append(task) + if timeout is not None and (time.time() - tic) > timeout: break diff --git a/stride/core.py b/stride/core.py index 536b7b50..e08de564 100644 --- a/stride/core.py +++ b/stride/core.py @@ -401,9 +401,11 @@ def _dealloc(*args): prev[nxt.name_idx] = input_grad + parallel_inits = [task.init_future for task in parallel_returns] + eager = not len(returns) or returns[-1]._eager if eager: - await asyncio.gather(*returns) + await asyncio.gather(*(returns + parallel_inits)) else: summ_returns = [] summ_dependencies = [] @@ -413,7 +415,7 @@ def _dealloc(*args): for runtime_deps in ret._dependencies.values(): summ_dependencies += list(runtime_deps.values()) - await asyncio.gather(*summ_returns) + await asyncio.gather(*(summ_returns + parallel_inits)) self.clear_graph() diff --git a/stride/problem/data.py b/stride/problem/data.py index 37182d68..3a296816 100644 --- a/stride/problem/data.py +++ b/stride/problem/data.py @@ -405,7 +405,7 @@ def process_grad(self, global_prec=True, **kwargs): self.grad.apply_prec(**kwargs) return self.grad - def apply_prec(self, prec_scale=4.0, prec_smooth=None, 
prec_op=None, prec=None, **kwargs): + def apply_prec(self, prec_scale=4.0, prec_smooth=1.0, prec_op=None, prec=None, **kwargs): """ Apply a pre-conditioner to the current field. diff --git a/stride/problem/problem.py b/stride/problem/problem.py index 399fef0a..49015f86 100644 --- a/stride/problem/problem.py +++ b/stride/problem/problem.py @@ -148,7 +148,7 @@ def space_resample(self, new_spacing, new_extra=None, new_absorbing=None, **kwar self.medium.fields[field]._resample(old_spacing, new_spacing, slowness=True, **kwargs) else: self.medium.fields[field]._resample(old_spacing, new_spacing, **kwargs) - return [self.medium.fields[field] for field in self.medium.fields] + return tuple(self.medium.fields[field] for field in self.medium.fields) def time_resample(self, new_step, new_num=None, **kwargs): ''' From 47cdc8f5a5c2e6bf4e33013aa41040d3fc25d36a Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 18 Aug 2025 15:47:25 +0100 Subject: [PATCH 06/23] Update undersampled time definition --- stride/physics/common/devito.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index 43fd97f6..b9e2d741 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -558,7 +558,6 @@ def _time_undersampled(self, name, factor, time_bounds=None, offset=None): time_under = devito.ConditionalDimension('%s%d' % (name, self._time_under_count), parent=time_dim, - factor=factor, condition=condition) self._time_under_count += 1 From bf8bc449b671a13bcff83682ca0e7408299bbe25 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 18 Aug 2025 17:45:30 +0100 Subject: [PATCH 07/23] Update undersampled time definition --- stride/physics/common/devito.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index b9e2d741..43fd97f6 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -558,6 +558,7 @@ def 
_time_undersampled(self, name, factor, time_bounds=None, offset=None): time_under = devito.ConditionalDimension('%s%d' % (name, self._time_under_count), parent=time_dim, + factor=factor, condition=condition) self._time_under_count += 1 From 4acc1ea66074e54f37284db0cbceaa9693b4947a Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Mon, 18 Aug 2025 18:59:08 +0100 Subject: [PATCH 08/23] Patch time bounds --- stride/physics/common/devito.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index 43fd97f6..e4b2826c 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -562,7 +562,8 @@ def _time_undersampled(self, name, factor, time_bounds=None, offset=None): condition=condition) self._time_under_count += 1 - buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 + buffer_size = (self.time.extended_num - time_bounds[0] + factor) // factor + 1 + # buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 return time_under, buffer_size From 16b7ca15146627e241f1076b8b78b733a3480497 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 17 Sep 2025 09:55:34 +0100 Subject: [PATCH 09/23] Reinstate time bounds buffer --- stride/physics/common/devito.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index e4b2826c..43fd97f6 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -562,8 +562,7 @@ def _time_undersampled(self, name, factor, time_bounds=None, offset=None): condition=condition) self._time_under_count += 1 - buffer_size = (self.time.extended_num - time_bounds[0] + factor) // factor + 1 - # buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 + buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 return time_under, buffer_size From 424397b02522425659d7a5e22c0f6a8af1c03553 Mon 
Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 18 Sep 2025 14:50:28 +0100 Subject: [PATCH 10/23] Ensure iteration carried gradient processing --- stride/__init__.py | 3 +-- stride/optimisation/optimisers/optimiser.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index 50bb2d3f..20a03510 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -232,7 +232,6 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa filter_traces = kwargs.pop('filter_traces', True) filter_wavelets = kwargs.pop('filter_wavelets', filter_traces) - fw3d_mode = kwargs.get('fw3d_mode', False) filter_wavelets_relaxation = kwargs.pop('filter_wavelets_relaxation', 0.75) filter_traces_relaxation = kwargs.pop('filter_traces_relaxation', 0.75 if filter_wavelets else 1.00) @@ -453,9 +452,9 @@ async def loop(worker, shot_id): fun.clear_graph() iteration.add_loss(fun) - iteration.add_completed(sub_problem.shot) logger.perf('Functional value for shot %d: %s' % (shot_id, fun)) + iteration.add_completed(sub_problem.shot) logger.perf('Retrieved test step for shot %d (%d out of %d)' % (sub_problem.shot_id, iteration.num_completed, num_shots)) diff --git a/stride/optimisation/optimisers/optimiser.py b/stride/optimisation/optimisers/optimiser.py index 770ac962..85cdc272 100644 --- a/stride/optimisation/optimisers/optimiser.py +++ b/stride/optimisation/optimisers/optimiser.py @@ -79,8 +79,8 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs): logger = mosaic.logger() logger.perf('Updating variable %s,' % self.variable.name) - problem = kwargs.pop('problem', None) - iteration = kwargs.pop('iteration', None) + problem = kwargs.get('problem', None) + iteration = kwargs.get('iteration', None) if processed_grad is None: if grad is None: From 2b24c8c7e5cd017f035feac03db95c8800f00bd7 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 25 Sep 2025 16:23:37 +0100 Subject: [PATCH 11/23] Changed 
variable update procedure --- stride/optimisation/optimisers/gradient_descent.py | 4 ---- stride/optimisation/optimisers/optimiser.py | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/stride/optimisation/optimisers/gradient_descent.py b/stride/optimisation/optimisers/gradient_descent.py index 21a7dd8e..bd9cb1b6 100644 --- a/stride/optimisation/optimisers/gradient_descent.py +++ b/stride/optimisation/optimisers/gradient_descent.py @@ -29,7 +29,3 @@ async def pre_process(self, grad=None, processed_grad=None, **kwargs): processed_grad=processed_grad, **kwargs) return processed_grad - - def update_variable(self, step_size, variable, direction): - variable.data[:] -= step_size * direction.data - return variable diff --git a/stride/optimisation/optimisers/optimiser.py b/stride/optimisation/optimisers/optimiser.py index 85cdc272..64f5f306 100644 --- a/stride/optimisation/optimisers/optimiser.py +++ b/stride/optimisation/optimisers/optimiser.py @@ -229,7 +229,7 @@ async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs): variable = self.variable.transform(self.variable) else: variable = self.variable - upd_variable = self.update_variable(next_step, variable, direction) + upd_variable = self.update_variable(next_step, variable, direction, **kwargs) if self.variable.transform is not None: upd_variable = self.variable.transform(upd_variable) self.variable.data[:] = upd_variable.data.copy() @@ -278,8 +278,7 @@ async def post_process(self, **kwargs): self.variable.release_grad() - @abstractmethod - def update_variable(self, step_size, variable, direction): + def update_variable(self, step_size, variable, direction, **kwargs): """ Parameters @@ -297,7 +296,8 @@ def update_variable(self, step_size, variable, direction): Updated variable. 
""" - pass + variable.data[:] -= step_size * direction.data + return variable def reset(self, **kwargs): """ From a3430c4d72cc18300f349e1f32b22dda9ba1781c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Thu, 23 Oct 2025 17:59:47 +0100 Subject: [PATCH 12/23] Ensure autotuning is used --- stride/physics/common/devito.py | 2 +- stride/physics/common/import_devito.py | 1 + stride/physics/iso_acoustic/devito.py | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index 43fd97f6..6e333f77 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -1026,7 +1026,7 @@ def run(self, **kwargs): if arg.name in self.grid.vars: default_kwargs[arg.name] = self.grid.vars[arg.name] - autotune = kwargs.pop('autotune', None) + autotune = kwargs.get('autotune', None) default_kwargs.update(kwargs) if self.grid.time_dim: diff --git a/stride/physics/common/import_devito.py b/stride/physics/common/import_devito.py index 7403c359..a70b36b6 100644 --- a/stride/physics/common/import_devito.py +++ b/stride/physics/common/import_devito.py @@ -2,6 +2,7 @@ from devito import * # noqa: F401 from devito.types import Symbol, Scalar # noqa: F401 from devito.symbolics import INT, IntDiv, CondEq # noqa: F401 +from devito.tools import frozendict # noqa: F401 from devito import TimeFunction as TimeFunctionOSS # noqa: F401 try: diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index e499da0c..6a8c3f55 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -421,7 +421,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): if dump_forward_wavefield: if dump_wavefield_id == shot.id: factor = dump_forward_wavefield \ - if isinstance(dump_forward_wavefield, int) else self.undersampling_factor + if type(dump_forward_wavefield) == int else self.undersampling_factor layers = devito.Host if 
is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_dump', time_bounds=time_bounds, @@ -764,7 +764,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k dump_wavefield_id = kwargs.pop('dump_wavefield_id', shot.id) if dump_adjoint_wavefield and dump_wavefield_id == shot.id: factor = dump_adjoint_wavefield \ - if isinstance(dump_adjoint_wavefield, int) else self.undersampling_factor + if type(dump_adjoint_wavefield) == int else self.undersampling_factor layers = devito.Host if is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_a_dump', time_bounds=time_bounds, From 608f41f90a8e78b436c51d0961f4fa55c6534490 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 4 Nov 2025 16:38:55 +0000 Subject: [PATCH 13/23] Workaround for Devito bug on spilling/CondEq --- stride/physics/common/devito.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/stride/physics/common/devito.py b/stride/physics/common/devito.py index 6e333f77..9c1cab7e 100644 --- a/stride/physics/common/devito.py +++ b/stride/physics/common/devito.py @@ -516,10 +516,13 @@ def undersampled_time_function(self, name, factor, time_bounds=None, Generated function. 
""" + layers = kwargs.pop('layers', devito.NoLayers) + time_bounds = time_bounds or (0, self.time.extended_num-1) time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) - time_under, buffer_size = self._time_undersampled('time_under', factor, time_bounds) + time_under, buffer_size = self._time_undersampled('time_under', factor, time_bounds, + layers=layers) compression = kwargs.pop('compression', None) @@ -530,6 +533,7 @@ def undersampled_time_function(self, name, factor, time_bounds=None, save=buffer_size, dtype=kwargs.pop('dtype', self.dtype), compression=compression, + layers=layers, **kwargs) return fun @@ -545,16 +549,20 @@ def undersampled_time_derivative(self, fun, factor, time_bounds=None, offset=Non return deriv - def _time_undersampled(self, name, factor, time_bounds=None, offset=None): + def _time_undersampled(self, name, factor, time_bounds=None, offset=None, layers=devito.NoLayers): time_bounds = time_bounds or (0, self.time.extended_num - 1) time_bounds = (time_bounds[0] or 0, time_bounds[1] or self.time.extended_num - 1) offset = offset or (0, 0) time_dim = self.devito_grid.time_dim - condition = sympy.And(devito.CondEq(time_dim % factor, 0), - devito.Ge(time_dim, time_bounds[0] + offset[0]), - devito.Le(time_dim, time_bounds[1] - offset[1]), ) + # TODO Current incompatibility between disk spilling and conditional dimensions + if layers in [devito.DiskHost, devito.DiskHostDevice, devito.DiskDevice]: + condition = None + else: + condition = sympy.And(devito.CondEq(time_dim % factor, 0), + devito.Ge(time_dim, time_bounds[0] + offset[0]), + devito.Le(time_dim, time_bounds[1] - offset[1]), ) time_under = devito.ConditionalDimension('%s%d' % (name, self._time_under_count), parent=time_dim, @@ -562,7 +570,10 @@ def _time_undersampled(self, name, factor, time_bounds=None, offset=None): condition=condition) self._time_under_count += 1 - buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 + if layers in 
[devito.DiskHost, devito.DiskHostDevice, devito.DiskDevice]: + buffer_size = (self.time.extended_num - 1 + factor) // factor + 1 + else: + buffer_size = (time_bounds[1] - time_bounds[0] + factor) // factor + 1 return time_under, buffer_size From 1a5034f53bdc5a93d305d2b209288ca1a1353587 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 5 Nov 2025 14:50:52 +0000 Subject: [PATCH 14/23] Loading performance improvements --- mosaic/file_manipulation/h5.py | 12 ++-- stride/problem/acquisitions.py | 107 ++++++++++++++++++--------------- stride/problem/geometry.py | 16 +++-- stride/problem/medium.py | 4 +- stride/problem/transducers.py | 10 ++- 5 files changed, 78 insertions(+), 71 deletions(-) diff --git a/mosaic/file_manipulation/h5.py b/mosaic/file_manipulation/h5.py index 54afddd8..a52c19b1 100644 --- a/mosaic/file_manipulation/h5.py +++ b/mosaic/file_manipulation/h5.py @@ -175,23 +175,23 @@ def read(obj, lazy=True, filter=None, only=None): def _read_dataset(obj, lazy=True): - if 'is_none' in obj.attrs and obj.attrs['is_none']: + if obj.attrs.get('is_none'): return None if obj.attrs['is_ndarray']: - def load(): - return obj[()] + if lazy is True: + def load(): + return obj[()] - setattr(obj, 'load', load) + setattr(obj, 'load', load) - if lazy is True: return obj else: return obj[()] - elif 'is_bytes' in obj.attrs and obj.attrs['is_bytes']: + elif obj.attrs.get('is_bytes'): obj = obj[()].tobytes() return obj diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 95124a14..47fc0dd2 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -1,7 +1,6 @@ import functools import numpy as np -from collections import OrderedDict from cached_property import cached_property import mosaic.types @@ -106,8 +105,8 @@ def __init__(self, id, name=None, problem=None, **kwargs): self._acquisitions = None self._sequence = None - self._sources = OrderedDict() - self._receivers = OrderedDict() + self._source_ids = [] + 
self._receiver_ids = [] self.wavelets = None self.observed = None @@ -118,10 +117,10 @@ def __init__(self, id, name=None, problem=None, **kwargs): if sources is not None and receivers is not None: for source in sources: - self._sources[source.id] = source + self._source_ids.append(source.id) for receiver in receivers: - self._receivers[receiver.id] = receiver + self._receiver_ids.append(receiver.id) self.wavelets = self._traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) @@ -144,7 +143,7 @@ def source_ids(self): Get ids of sources in this Shot in a list. """ - return list(self._sources.keys()) + return self._source_ids @property def receiver_ids(self): @@ -152,23 +151,39 @@ def receiver_ids(self): Get ids of receivers in this Shot in a list. """ - return list(self._receivers.keys()) + return self._receiver_ids - @property + @cached_property def sources(self): """ Get sources in this Shot as a list. """ - return list(self._sources.values()) + locations = [] + for loc_id in self.source_ids: + try: + loc = self._geometry.get(loc_id) + except AttributeError: + return None + locations.append(loc) - @property + return locations + + @cached_property def receivers(self): """ Get receivers in this Shot as a list. 
""" - return list(self._receivers.values()) + locations = [] + for loc_id in self.receiver_ids: + try: + loc = self._geometry.get(loc_id) + except AttributeError: + return None + locations.append(loc) + + return locations @property def slow_time_index(self): @@ -453,36 +468,37 @@ def __get_desc__(self, **kwargs): return description def __set_desc__(self, description, **kwargs): + try: + del self.__dict__['sources'] + except: + pass + try: + del self.__dict__['receivers'] + except: + pass + compressed = kwargs.pop('compressed', False) self.id = description.id for source_id in description.source_ids: - if self._geometry: - source = self._geometry.get(source_id) - else: - source = None - self._sources[source_id] = source + self._source_ids.append(source_id) for receiver_id in description.receiver_ids: - if self._geometry: - receiver = self._geometry.get(receiver_id) - else: - receiver = None - self._receivers[receiver_id] = receiver + self._receiver_ids.append(receiver_id) lazy_loading = kwargs.pop('lazy_loading', False) self.wavelets = self._traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) - if 'wavelets' in description and not lazy_loading: + if not lazy_loading and 'wavelets' in description: self.wavelets.__set_desc__(description.wavelets, **kwargs) self.observed = self._traces(name='observed', transducer_ids=self.receiver_ids, compressed=compressed, grid=self.grid) - if 'observed' in description and not lazy_loading: + if not lazy_loading and 'observed' in description: self.observed.__set_desc__(description.observed, **kwargs) self.delays = self._traces(name='delays', transducer_ids=self.source_ids, shape=(len(self.source_ids), 1), grid=self.grid) - if 'delays' in description and not lazy_loading: + if not lazy_loading and 'delays' in description: self.delays.__set_desc__(description.delays, **kwargs) @@ -532,7 +548,7 @@ def __init__(self, id, acq=0, name=None, problem=None, **kwargs): self._geometry = geometry self._acquisitions = 
acquisitions - self._shots = OrderedDict() + self._shots = dict() self._shot_selection = [] @property @@ -607,14 +623,13 @@ def get(self, frame): Found Shot. """ - if isinstance(frame, (np.int32, np.int64)): - frame = int(frame) - - if not isinstance(frame, int) or frame < 0: + try: + return self._shots[frame] + except KeyError: + if isinstance(frame, (np.int32, np.int64)): + return self.get(int(frame)) raise ValueError('Shot frames have to be positive integer numbers') - return self._shots[frame] - def set(self, frame, item): """ Change an existing shot in the Sequence. @@ -736,8 +751,8 @@ def __init__(self, name='acquisitions', problem=None, **kwargs): geometry = kwargs.pop('geometry', None) self._geometry = geometry - self._shots = OrderedDict() - self._sequences = OrderedDict() + self._shots = dict() + self._sequences = dict() self._shot_selection = [] self._sequence_selection = [] self._prev_load = None, None @@ -814,7 +829,7 @@ def remaining_shots(self): Get dict of all shots that have no observed allocated. """ - shots = OrderedDict() + shots = dict() for shot_id, shot in self._shots.items(): if not shot.observed.allocated: shots[shot_id] = shot @@ -887,14 +902,13 @@ def get(self, id): Found Shot. """ - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._shots[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Shot IDs have to be positive integer numbers') - return self._shots[id] - def get_sequence(self, id): """ Get a sequence from the Acquisitions with a known id. @@ -910,14 +924,13 @@ def get_sequence(self, id): Found Sequence. 
""" - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._sequences[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Sequence IDs have to be positive integer numbers') - return self._sequences[id] - def set(self, item): """ Change an existing shot in the Acquisitions. @@ -1338,7 +1351,7 @@ def __set_desc__(self, description, **kwargs): problem=self.problem, grid=self.grid) self.add(shot) - shot = self.get(shot_desc.id) + shot = self._shots[shot_desc.id] shot.__set_desc__(shot_desc, compressed=compressed, **kwargs) if 'sequences' in description: diff --git a/stride/problem/geometry.py b/stride/problem/geometry.py index dce199e7..28b5cdbe 100644 --- a/stride/problem/geometry.py +++ b/stride/problem/geometry.py @@ -1,5 +1,4 @@ import numpy as np -from collections import OrderedDict import mosaic.types from .base import GriddedSaved, ProblemBase @@ -136,7 +135,7 @@ def __init__(self, name='geometry', problem=None, **kwargs): else: transducers = kwargs.pop('transducers', None) - self._locations = OrderedDict() + self._locations = dict() self._transducers = transducers def add(self, id, transducer, coordinates, orientation=None): @@ -198,14 +197,13 @@ def get(self, id): Found TransducerLocation. """ - if isinstance(id, (np.int32, np.int64)): - id = int(id) - - if not isinstance(id, int) or id < 0: + try: + return self._locations[id] + except KeyError: + if isinstance(id, (np.int32, np.int64)): + return self.get(int(id)) raise ValueError('Transducer IDs have to be positive integer numbers') - return self._locations[id] - def get_slice(self, start=None, end=None, step=None): """ Get a slice of the indices of the locations using ``slice(start, stop, step)``. @@ -225,7 +223,7 @@ def get_slice(self, start=None, end=None, step=None): Found transducer locations in the slice. 
""" - section = OrderedDict() + section = dict() if start is None: _range = range(end) elif step is None: diff --git a/stride/problem/medium.py b/stride/problem/medium.py index 0ef919eb..042a5532 100644 --- a/stride/problem/medium.py +++ b/stride/problem/medium.py @@ -1,6 +1,4 @@ -from collections import OrderedDict - from .base import ProblemBase @@ -34,7 +32,7 @@ class Medium(ProblemBase): def __init__(self, name='medium', problem=None, **kwargs): super().__init__(name=name, problem=problem, **kwargs) - self._fields = OrderedDict() + self._fields = dict() def _get(self, item): if item in super().__getattribute__('_fields').keys(): diff --git a/stride/problem/transducers.py b/stride/problem/transducers.py index d52977db..d424d6ae 100644 --- a/stride/problem/transducers.py +++ b/stride/problem/transducers.py @@ -1,6 +1,4 @@ -from collections import OrderedDict - from mosaic.utils import camel_case from .base import ProblemBase @@ -34,7 +32,7 @@ class Transducers(ProblemBase): def __init__(self, name='transducers', problem=None, **kwargs): super().__init__(name=name, problem=problem, **kwargs) - self._transducers = OrderedDict() + self._transducers = dict() def add(self, item): """ @@ -69,11 +67,11 @@ def get(self, id): Found Transducer. """ - if not isinstance(id, int) or id < 0: + try: + return self._transducers[id] + except KeyError: raise ValueError('Transducer IDs have to be positive integer numbers') - return self._transducers[id] - def get_slice(self, start=None, end=None, step=None): """ Get a slice of the indices of the transducer using ``slice(start, stop, step)``. 
From ef787a891e7b00a63a7e3fa1931bcb2764b2f0f4 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 5 Nov 2025 19:54:39 +0000 Subject: [PATCH 15/23] Fix shot subproblem --- stride/problem/acquisitions.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 47fc0dd2..48a4093b 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -314,12 +314,10 @@ def sub_problem(self, shot, sub_problem): grid=self.grid, geometry=sub_problem.geometry) for source_id in self.source_ids: - location = sub_problem.geometry.get(source_id) - shot._sources[location.id] = location + shot._source_ids.append(source_id) for receiver_id in self.receiver_ids: - location = sub_problem.geometry.get(receiver_id) - shot._receivers[location.id] = location + shot._receiver_ids.append(receiver_id) if self.wavelets is not None: shot.wavelets = self.wavelets From 702ea15c30f6a03818f9d2991fdaa2bb8d260a7c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Fri, 7 Nov 2025 13:23:37 +0000 Subject: [PATCH 16/23] Introduced DiskTraces --- stride/problem/acquisitions.py | 29 ++++-- stride/problem/base.py | 2 +- stride/problem/data.py | 169 ++++++++++++++++++++++++++++++++- 3 files changed, 191 insertions(+), 9 deletions(-) diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 48a4093b..058d4de5 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -6,7 +6,7 @@ import mosaic.types from mosaic.file_manipulation import h5 -from .data import Traces +from .data import Traces, DiskTraces from .base import ProblemBase from .. 
import plotting @@ -441,9 +441,12 @@ def append_observed(self, *args, **kwargs): else: self._acquisitions.dump(*args, shot_ids=[self.id], **kwargs) - @staticmethod - def _traces(*args, **kwargs): - return Traces(*args, **kwargs) + def _traces(self, *args, **kwargs): + if kwargs.pop('lazy_loading', False): + return DiskTraces(*args, **kwargs, + path='/shots/%d/%s' % (self.id, kwargs.get('name'))) + else: + return Traces(*args, **kwargs) def __get_desc__(self, **kwargs): description = { @@ -487,15 +490,27 @@ def __set_desc__(self, description, **kwargs): lazy_loading = kwargs.pop('lazy_loading', False) - self.wavelets = self._traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) + self.wavelets = self._traces( + name='wavelets', transducer_ids=self.source_ids, + grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) if not lazy_loading and 'wavelets' in description: self.wavelets.__set_desc__(description.wavelets, **kwargs) - self.observed = self._traces(name='observed', transducer_ids=self.receiver_ids, compressed=compressed, grid=self.grid) + self.observed = self._traces( + name='observed', transducer_ids=self.receiver_ids, + compressed=compressed, grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) if not lazy_loading and 'observed' in description: self.observed.__set_desc__(description.observed, **kwargs) - self.delays = self._traces(name='delays', transducer_ids=self.source_ids, shape=(len(self.source_ids), 1), grid=self.grid) + self.delays = self._traces( + name='delays', transducer_ids=self.source_ids, + shape=(len(self.source_ids), 1), grid=self.grid, + lazy_loading=lazy_loading, filename=kwargs.get('filename', None), + ) if not lazy_loading and 'delays' in description: self.delays.__set_desc__(description.delays, **kwargs) diff --git a/stride/problem/base.py b/stride/problem/base.py index 41e1b4f1..579ab269 100644 --- a/stride/problem/base.py +++ b/stride/problem/base.py 
@@ -271,7 +271,7 @@ def load(self, *args, **kwargs): self._grid.slow_time = slow_time - self.__set_desc__(description, **kwargs) + self.__set_desc__(description, filename=file.filename, **kwargs) def grid_description(self): """ diff --git a/stride/problem/data.py b/stride/problem/data.py index 3a296816..d875fd70 100644 --- a/stride/problem/data.py +++ b/stride/problem/data.py @@ -18,6 +18,7 @@ import mosaic from mosaic.core.tessera import PickleClass from mosaic.comms.compression import maybe_compress, decompress +from mosaic.file_manipulation import h5 from .base import GriddedSaved from ..core import Variable @@ -25,7 +26,7 @@ __all__ = ['Data', 'StructuredData', 'Scalar', 'ScalarField', 'VectorField', 'Traces', - 'SparseField', 'SparseCoordinates'] + 'DiskTraces', 'SparseField', 'SparseCoordinates'] def inv_transform(x): @@ -1640,6 +1641,172 @@ def __set_desc__(self, description, **kwargs): self._transducer_ids = description.transducer_ids +@mosaic.tessera +class DiskTraces(Traces): + """ + Objects of this type describe a set of time traces defined over the time grid, + which are lazily loaded from disk. + + See ``Traces`` for definition of parameters. + + """ + + def __init__(self, **kwargs): + self._filename = kwargs.pop('filename', None) + self._path = kwargs.pop('path', None) + + data = kwargs.pop('data', None) + super().__init__(**kwargs) + + @property + def _data(self): + return None + + @_data.setter + def _data(self, value): + pass + + def load(self, **kwargs): + with h5.HDF5(filename=self._filename, **kwargs, mode='r') as file: + file = file.file + data = file['%s/data' % self.path][()] + + return Traces( + data=data, + transducer_ids=self.transducer_ids, + grid=self.grid, + ) + + def get(self, id): + """ + Get one trace based on a transducer ID, selecting the inner domain. + + Parameters + ---------- + id : int + Transducer ID. + + Returns + ------- + 1d-array + Time trace. 
+ + """ + return self.load().get(id) + + def get_extended(self, id): + """ + Get one trace based on a transducer ID, selecting the extended domain. + + Parameters + ---------- + id : int + Transducer ID. + + Returns + ------- + 1d-array + Time trace. + + """ + return self.load().get_extended(id) + + def plot(self, **kwargs): + """ + Plot the inner domain of the traces as a shot gather. + + Parameters + ---------- + kwargs + Arguments for plotting. + + Returns + ------- + axes + Axes on which the plotting is done. + + """ + return self.load().plot(**kwargs) + + def plot_one(self, id, **kwargs): + """ + Plot the inner domain of one of the traces. + + Parameters + ---------- + id : int + Transducer ID. + kwargs + Arguments for plotting. + + Returns + ------- + axes + Axes on which the plotting is done. + + """ + return self.load().plot_one(id, **kwargs) + + def _resample(self, old_step, new_step, new_num, **kwargs): + ''' + Resample the current trace to a new time-spacing. + + Parameters + ---------- + old_step: float + The original time step. + new_step: float + The new time step. + new_num + The length of the trace. 
''' + + # TODO Enable lazy resampling by marking Traces to be resampled on load + raise RuntimeError('DiskTraces cannot be resampled yet') + + def __get_desc__(self, **kwargs): + return self.load().__get_desc__(**kwargs) + + def __set_desc__(self, description, **kwargs): + del description.data + super().__set_desc__(description, **kwargs) + + _serialisation_attrs = ['name', 'uname', '_init_name', '_shape', '_extended_shape', '_inner', + '_dtype', 'needs_grad', '_compressed', '_compression', + 'transform', 'grad', 'prec', '_transducer_ids', + '_filename', '_path', '_grid'] + + def _serialisation_helper(self): + state = {} + + for attr in self._serialisation_attrs: + state[attr] = getattr(self, attr) + + return state + + @classmethod + def _deserialisation_helper(cls, state): + instance = Traces.__new__(Traces) + + with h5.HDF5(filename=state.pop('_filename'), mode='r') as file: + file = file.file + try: + data = file['%s/data' % state.pop('_path')][()] + except KeyError: + data = None + + instance._data = data + + for attr, value in state.items(): + setattr(instance, attr, value) + + return instance + + def __reduce__(self): + state = self._serialisation_helper() + return self._deserialisation_helper, (state,) + + @mosaic.tessera class SparseField(StructuredData): """ From 6cba3c64688249cea8865c9b46b197984e593bdd Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Fri, 7 Nov 2025 14:10:38 +0000 Subject: [PATCH 17/23] Fix duplicated filename --- stride/problem/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stride/problem/base.py b/stride/problem/base.py index 579ab269..e79a241f 100644 --- a/stride/problem/base.py +++ b/stride/problem/base.py @@ -271,7 +271,8 @@ def load(self, *args, **kwargs): self._grid.slow_time = slow_time - self.__set_desc__(description, filename=file.filename, **kwargs) + kwargs['filename'] = kwargs.pop('filename', file.filename) + self.__set_desc__(description, **kwargs) def grid_description(self): """ From 
bb7599a4ce437293a7139f6d0762834e15a480d7 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 18 Nov 2025 14:15:24 +0000 Subject: [PATCH 18/23] Test await returns in core --- stride/core.py | 8 +++++--- stride/problem/base.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/stride/core.py b/stride/core.py index e08de564..e0647424 100644 --- a/stride/core.py +++ b/stride/core.py @@ -401,11 +401,11 @@ def _dealloc(*args): prev[nxt.name_idx] = input_grad - parallel_inits = [task.init_future for task in parallel_returns] + # parallel_inits = [task.init_future for task in parallel_returns] eager = not len(returns) or returns[-1]._eager if eager: - await asyncio.gather(*(returns + parallel_inits)) + await asyncio.gather(*(returns + parallel_returns)) else: summ_returns = [] summ_dependencies = [] @@ -415,10 +415,12 @@ def _dealloc(*args): for runtime_deps in ret._dependencies.values(): summ_dependencies += list(runtime_deps.values()) - await asyncio.gather(*(summ_returns + parallel_inits)) + await asyncio.gather(*(summ_returns + parallel_returns)) self.clear_graph() + mosaic.logger().info('Done %s' % mosaic.runtime().uid) + return self def detach(self, *args, **kwargs): diff --git a/stride/problem/base.py b/stride/problem/base.py index e79a241f..cbbbc0e2 100644 --- a/stride/problem/base.py +++ b/stride/problem/base.py @@ -72,7 +72,7 @@ def slow_time(self): def resample(self, grid=None, space=None, time=None, slow_time=None): raise NotImplementedError('Resampling has not been implemented in this class yet.' 
+ - ' Alternatively, try to access space_resample or time_resample via problem.') + ' Alternatively, try to access space_resample or time_resample via problem.') class Saved: @@ -149,6 +149,7 @@ def load(self, *args, **kwargs): with h5.HDF5(*args, **kwargs, mode='r') as file: description = file.load(filter=kwargs.pop('filter', None), only=kwargs.pop('only', None)) + kwargs['filename'] = kwargs.pop('filename', file.filename) self.__set_desc__(description, **kwargs) def rm(self, *args, **kwargs): From 7acaf45b68e16039ebf552f1e7809745918eb09b Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 19 Nov 2025 11:59:56 +0000 Subject: [PATCH 19/23] Test bottleneck --- mosaic/runtime/monitor.py | 49 +++++++++++++++++++++++++++++++++++++++ stride/core.py | 8 +++---- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/mosaic/runtime/monitor.py b/mosaic/runtime/monitor.py index 30f25909..bf291c89 100644 --- a/mosaic/runtime/monitor.py +++ b/mosaic/runtime/monitor.py @@ -5,6 +5,7 @@ import asyncio import datetime import subprocess as cmd_subprocess +from collections import defaultdict import mosaic from .runtime import Runtime, RuntimeProxy @@ -75,6 +76,9 @@ def __init__(self, **kwargs): self._monitored_tessera = dict() self._monitored_tasks = dict() + self._runtime_tessera = defaultdict(list) + self._runtime_tasks = defaultdict(list) + self._dirty_tessera = set() self._dirty_tasks = set() @@ -329,6 +333,7 @@ def add_task_profile(self, sender_id, msgs): def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) + self._runtime_tessera[runtime_id].append(uid) obj = self._monitored_tessera[uid] obj.add_event(sender_id, **kwargs) @@ -338,6 +343,7 @@ def _add_tessera_event(self, sender_id, runtime_id, uid, **kwargs): def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): if uid not in self._monitored_tasks: self._monitored_tasks[uid] = 
MonitoredObject(runtime_id, uid, tessera_id=tessera_id) + self._runtime_tasks[runtime_id].append(uid) obj = self._monitored_tasks[uid] obj.add_event(sender_id, **kwargs) @@ -347,6 +353,7 @@ def _add_task_event(self, sender_id, runtime_id, uid, tessera_id, **kwargs): def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): if uid not in self._monitored_tessera: self._monitored_tessera[uid] = MonitoredObject(runtime_id, uid) + self._runtime_tessera[runtime_id].append(uid) obj = self._monitored_tessera[uid] obj.add_profile(sender_id, profile) @@ -355,6 +362,7 @@ def _add_tessera_profile(self, sender_id, runtime_id, uid, profile): def _add_task_profile(self, sender_id, runtime_id, uid, tessera_id, profile): if uid not in self._monitored_tasks: self._monitored_tasks[uid] = MonitoredObject(runtime_id, uid, tessera_id=tessera_id) + self._runtime_tasks[runtime_id].append(uid) obj = self._monitored_tasks[uid] obj.add_profile(sender_id, profile) @@ -457,6 +465,43 @@ async def stop(self, sender_id=None): await super().stop(sender_id) + def disconnect(self, sender_id, uid): + """ + Disconnect specific remote runtime. + + Parameters + ---------- + sender_id : str + uid : str + + Returns + ------- + + """ + super().disconnect(sender_id, uid) + + # remove runtime from monitored nodes + try: + del self._monitored_nodes[uid] + except KeyError: + pass + + # remove monitored tessera + for obj_uid in self._runtime_tessera[uid]: + try: + del self._monitored_tessera[obj_uid] + except KeyError: + pass + del self._runtime_tessera[uid] + + # remove monitored tasks + for obj_uid in self._runtime_tasks[uid]: + try: + del self._monitored_tasks[obj_uid] + except KeyError: + pass + del self._runtime_tasks[uid] + async def select_worker(self, sender_id): """ Select appropriate worker to allocate a tessera. 
@@ -511,3 +556,7 @@ async def barrier(self, sender_id, timeout=None): if timeout is not None and (time.time() - tic) > timeout: break + + self._monitored_tasks = dict() + self._runtime_tasks = defaultdict(list) + self._dirty_tasks = set() diff --git a/stride/core.py b/stride/core.py index e0647424..cc4e2737 100644 --- a/stride/core.py +++ b/stride/core.py @@ -401,11 +401,11 @@ def _dealloc(*args): prev[nxt.name_idx] = input_grad - # parallel_inits = [task.init_future for task in parallel_returns] + parallel_inits = [task.init_future for task in parallel_returns] eager = not len(returns) or returns[-1]._eager if eager: - await asyncio.gather(*(returns + parallel_returns)) + await asyncio.gather(*(returns + parallel_inits)) else: summ_returns = [] summ_dependencies = [] @@ -415,11 +415,11 @@ def _dealloc(*args): for runtime_deps in ret._dependencies.values(): summ_dependencies += list(runtime_deps.values()) - await asyncio.gather(*(summ_returns + parallel_returns)) + await asyncio.gather(*(summ_returns + parallel_inits)) self.clear_graph() - mosaic.logger().info('Done %s' % mosaic.runtime().uid) + mosaic.logger().info('==> Done %s' % mosaic.runtime().uid) return self From e42ac6c6ae4f61a73e831a9609572161538e044a Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 19 Nov 2025 14:07:14 +0000 Subject: [PATCH 20/23] Test warehouse accumulation --- stride/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stride/core.py b/stride/core.py index cc4e2737..f0bdcefc 100644 --- a/stride/core.py +++ b/stride/core.py @@ -419,8 +419,6 @@ def _dealloc(*args): self.clear_graph() - mosaic.logger().info('==> Done %s' % mosaic.runtime().uid) - return self def detach(self, *args, **kwargs): @@ -575,6 +573,8 @@ async def __call_adjoint__(self, grad, **kwargs): self.grad += grad + mosaic.logger().info('==> Done %s' % mosaic.runtime().uid) + def __repr__(self): return self.name From 3eee9fd6f2867e165196765ecb87f0032b76077b Mon Sep 17 00:00:00 2001 From: Carlos 
Cueto Date: Fri, 28 Nov 2025 14:30:51 +0000 Subject: [PATCH 21/23] Change time_M for iso-acoustic --- stride/core.py | 2 -- stride/physics/iso_acoustic/devito.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/stride/core.py b/stride/core.py index f0bdcefc..e08de564 100644 --- a/stride/core.py +++ b/stride/core.py @@ -573,8 +573,6 @@ async def __call_adjoint__(self, grad, **kwargs): self.grad += grad - mosaic.logger().info('==> Done %s' % mosaic.runtime().uid) - def __repr__(self): return self.name diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index 6a8c3f55..dcb0d7be 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -576,7 +576,7 @@ def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num)) op.run(dt=self.time.step, time_m=1, - time_M=time_bounds[1]-1, + time_M=time_bounds[1]-2, **functions, **devito_args) From 66caffd3fa3c80081a09b4600e1d4b347afb9d4c Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 2 Dec 2025 09:39:36 +0000 Subject: [PATCH 22/23] Quick fix to Devito issue --- stride/physics/iso_acoustic/devito.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index dcb0d7be..0c88a513 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -722,6 +722,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num)) fw3d_mode = kwargs.pop('fw3d_mode', False) + fw3d_mode = True # TODO Force fw3d_mode pending Devito-side fix platform = kwargs.get('platform', 'cpu') is_nvidia = platform is not None and 'nvidia' in platform From c06be079fc71565142743ac119568d7fd481cdee Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Tue, 2 Dec 2025 09:48:41 +0000 Subject: [PATCH 
23/23] Flake8 fix --- stride/optimisation/optimisers/optimiser.py | 2 +- stride/physics/iso_acoustic/devito.py | 4 ++-- stride/problem/transducers.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/stride/optimisation/optimisers/optimiser.py b/stride/optimisation/optimisers/optimiser.py index 64f5f306..dc71c864 100644 --- a/stride/optimisation/optimisers/optimiser.py +++ b/stride/optimisation/optimisers/optimiser.py @@ -1,6 +1,6 @@ import numpy as np -from abc import ABC, abstractmethod +from abc import ABC import mosaic diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index 0c88a513..44fcfd0e 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -421,7 +421,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): if dump_forward_wavefield: if dump_wavefield_id == shot.id: factor = dump_forward_wavefield \ - if type(dump_forward_wavefield) == int else self.undersampling_factor + if type(dump_forward_wavefield) is int else self.undersampling_factor layers = devito.Host if is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_dump', time_bounds=time_bounds, @@ -765,7 +765,7 @@ def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None, **k dump_wavefield_id = kwargs.pop('dump_wavefield_id', shot.id) if dump_adjoint_wavefield and dump_wavefield_id == shot.id: factor = dump_adjoint_wavefield \ - if type(dump_adjoint_wavefield) == int else self.undersampling_factor + if type(dump_adjoint_wavefield) is int else self.undersampling_factor layers = devito.Host if is_nvidia else devito.NoLayers p_dump = self.dev_grid.undersampled_time_function('p_a_dump', time_bounds=time_bounds, diff --git a/stride/problem/transducers.py b/stride/problem/transducers.py index d424d6ae..938b3ed9 100644 --- a/stride/problem/transducers.py +++ b/stride/problem/transducers.py @@ -91,7 +91,7 @@ def get_slice(self, start=None, 
end=None, step=None): Found transducers in the slice. """ - section = OrderedDict() + section = dict() if start is None: _range = range(end) elif step is None: