From 2f5bd19e522d57c09594639713cbddf7f7bcdb7e Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 14 Jan 2026 15:30:05 +0000 Subject: [PATCH 1/2] New loop restart system --- stride/__init__.py | 10 +- stride/optimisation/optimisation_loop.py | 155 ++++++++++++++--------- 2 files changed, 102 insertions(+), 63 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index 20a0351..0c22d2e 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -219,9 +219,6 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa num_iters = kwargs.pop('num_iters', 1) select_shots = kwargs.pop('select_shots', {}) - restart = kwargs.pop('restart', None) - restart_id = kwargs.pop('restart_id', -1) - lazy_loading = kwargs.pop('lazy_loading', False) dump = kwargs.pop('dump', True) safe = kwargs.pop('safe', True) @@ -267,7 +264,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa if optimiser.reset_block: optimiser.reset() - for iteration in block.iterations(num_iters, restart=restart, restart_id=restart_id): + for iteration in block.iterations(num_iters): optimiser.clear_grad() if optimiser.reset_iteration: @@ -281,12 +278,15 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa (iteration.id+1, block.num_iterations, block.id+1, optimisation_loop.num_blocks)) - if dump and block.restart and not optimisation_loop.started: + if dump and optimisation_loop.restart and not optimisation_loop.started: if iteration.abs_id > 0: try: optimiser.load(path=problem.output_folder, project_name=problem.name, version=iteration.abs_id) + + logger.perf('\n') + logger.perf('Loaded optimiser variable for restart, version %d' % iteration.abs_id) except OSError: raise OSError('Optimisation loop cannot be restarted,' 'variable version or optimiser version %d cannot be found.' 
% diff --git a/stride/optimisation/optimisation_loop.py b/stride/optimisation/optimisation_loop.py index eaa8bc7..5e5026a 100644 --- a/stride/optimisation/optimisation_loop.py +++ b/stride/optimisation/optimisation_loop.py @@ -217,6 +217,19 @@ def next_run(self): self._runs[self._curr_run_idx] = IterationRun(self._curr_run_idx, self) return self.curr_run + def clear(self): + """ + Clear iteration. + + Returns + ------- + + """ + self._runs = { + 0: IterationRun(0, self), + } + self._curr_run_idx = 0 + def clear_run(self): """ Clear run memory. @@ -292,9 +305,9 @@ def append_iteration(self, *args, **kwargs): except AttributeError: return - self_desc = self.__get_desc__(**kwargs) - if h5.file_exists(*args, **dump_kwargs): + self_desc = self.__get_desc__(**kwargs) + description = { 'running_id': loop.running_id, 'num_blocks': loop.num_blocks, @@ -340,13 +353,20 @@ def __get_desc__(self, **kwargs): for run in self._runs.values(): description['runs'][str(run.id)] = run.__get_desc__(**kwargs) + if kwargs.pop('no_runs', False): + del description['runs'] + return description def __set_desc__(self, description, **kwargs): self.id = description.id self.abs_id = description.abs_id - runs = description.runs + try: + runs = description.runs + except AttributeError: + self.clear() + return if isinstance(runs, mosaic.types.Struct): runs = runs.values() @@ -415,7 +435,6 @@ def __init__(self, id, opt_loop, **kwargs): self._num_iterations = None self._iterations = dict() self._current_iteration = None - self.restart = False @property def num_iterations(self): @@ -477,35 +496,6 @@ def iterations(self, num, *iters, restart=None, restart_id=-1): Iteration iterables. 
""" - loop_restart = self._optimisation_loop.restart - restart = loop_restart if restart is None else restart - self.restart = restart - - if restart is False: - self.clear() - else: - if type(restart_id) is int and restart_id < 0: - curr_iter_id = self._current_iteration.id if self._current_iteration is not None else -1 - iteration = Iteration(curr_iter_id+1, self._optimisation_loop.running_id, - self, self._optimisation_loop) - - self._iterations[curr_iter_id+1] = iteration - self._optimisation_loop.running_id += 1 - self._current_iteration = iteration - - elif type(restart_id) is int and restart_id >= 0: - if restart_id not in self._iterations: - raise ValueError('Iteration %d does not exist, so loop cannot be ' - 'restarted from that point' % restart_id) - - self._current_iteration = self._iterations[restart_id] - - for index in range(restart_id+1, self._num_iterations): - if index in self._iterations: - del self._iterations[index] - - if self._current_iteration is not None: - self._optimisation_loop.running_id = self._current_iteration.abs_id+1 if self._num_iterations is None: self._num_iterations = num @@ -525,14 +515,17 @@ def iterations(self, num, *iters, restart=None, restart_id=-1): self._optimisation_loop.running_id += 1 self._current_iteration = self._iterations[index] + # dump the empty iteration to keep track of loop + self._current_iteration.append_iteration(no_runs=True) if len(zipped) > 1: - yield (self._iterations[index],) + zipped[1:] + yield (self._current_iteration,) + zipped[1:] else: - yield self._iterations[index] + yield self._current_iteration self._optimisation_loop.started = True - self._iterations[index].append_iteration() + # dump completed iteration to enable restart + self._current_iteration.append_iteration() def __get_desc__(self, **kwargs): legacy = kwargs.get('legacy', False) @@ -679,10 +672,10 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs): restart : int or bool, optional Whether or not attempt to 
restart the loop from a previous block. Defaults to ``False``. - restart_id : int, optional + restart_id : int or tuple, optional If an integer greater than zero, it will restart from - a specific block. Otherwise, it will restart from the latest - available block. + a specific iteration. If -1, it will restart from the latest + iteration. If a tuple, it will restart from ``(block_id, iteration_id)``. Returns ------- @@ -703,27 +696,74 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs): self.load(**load_kwargs) + except (OSError, AttributeError): + self.clear() + self.restart = False + + else: + # restart with an absolute iteration_id if type(restart_id) is int and restart_id >= 0: - if restart_id not in self._blocks: - raise ValueError('Block %d does not exist, so loop cannot be ' + if restart_id >= self.running_id: + raise ValueError('Iteration %d does not exist, so loop cannot be ' 'restarted from that point' % restart_id) - self._current_block = self._blocks[restart_id] - last_iter = list(self._current_block._iterations.values())[-1] - self.running_id = last_iter.abs_id + block = self._current_block + iter = list(self._current_block._iterations.values())[-1] - if restart_id-1 in self._blocks: - prev_block = self._blocks[restart_id-1] - last_iter = prev_block._iterations[prev_block.num_iterations-1] - self.running_id = last_iter.abs_id + found_iter = False + for block in self._blocks.values(): + for iter in block._iterations.values(): + if iter.abs_id >= restart_id: + found_iter = True + break + if found_iter: + break - for index in range(restart_id+1, self._num_blocks): - if index in self._blocks: - del self._blocks[index] + self._current_block = block + block._current_iteration = iter + self.running_id = iter.abs_id+1 - except (OSError, AttributeError): - self.clear() - self.restart = False + # restart with a tuple (block_id, iteration_id) + elif type(restart_id) is tuple: + block_id, iteration_id = restart_id + + # point to requested 
block + if block_id not in self._blocks: + raise ValueError('Block %d does not exist, so loop cannot be ' + 'restarted from that point' % block_id) + + self._current_block = self._blocks[block_id] + block = self._current_block + + # point to requested iteration + if iteration_id < 0: + last_iter = list(block._iterations.values())[-1] + self.running_id = last_iter.abs_id+1 + + else: + if iteration_id not in block._iterations: + raise ValueError('Iteration %d does not exist, so loop cannot be ' + 'restarted from that point' % iteration_id) + + block._current_iteration = block._iterations[iteration_id] + self.running_id = block._current_iteration.abs_id+1 + + # access restarted block and iteration + block = self._current_block + iter = self._current_block._current_iteration + + # clear iteration + iter.clear() + + # delete any other blocks + for index in range(block.id+1, self._num_blocks): + if index in self._blocks: + del self._blocks[index] + + # delete any other iterations + for index in range(iter.id+1, block._num_iterations): + if index in block._iterations: + del block._iterations[index] if self._num_blocks is None: self._num_blocks = num @@ -742,11 +782,9 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs): self._current_block = self._blocks[index] if len(zipped) > 1: - yield (self._blocks[index],) + zipped[1:] + yield (self._current_block,) + zipped[1:] else: - yield self._blocks[index] - - self.restart = False + yield self._current_block def dump(self, *args, **kwargs): """ @@ -767,6 +805,7 @@ def dump(self, *args, **kwargs): dump_kwargs = dict(path=self.problem.output_folder, project_name=self.problem.name, version=0) dump_kwargs.update(self._file_kwargs) + dump_kwargs.update(kwargs) super().dump(*args, **dump_kwargs) except AttributeError: From f3cb6fa32d35702b95b338905dd1c4d7f85edaf5 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 14 Jan 2026 16:21:17 +0000 Subject: [PATCH 2/2] Filter out previous loop shots --- 
stride/__init__.py | 4 +++ stride/problem/acquisitions.py | 48 ++++++++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index 0c22d2e..02d31c2 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -280,6 +280,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa if dump and optimisation_loop.restart and not optimisation_loop.started: if iteration.abs_id > 0: + # reload the latest version of the optimiser variable try: optimiser.load(path=problem.output_folder, project_name=problem.name, @@ -292,6 +293,9 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa 'variable version or optimiser version %d cannot be found.' % iteration.abs_id) + # ensure previously used shots are not repeated + problem.acquisitions.filter_shot_ids(optimisation_loop, **select_shots) + shot_ids = problem.acquisitions.select_shot_ids(**select_shots) num_shots = len(shot_ids) diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 058d4de..10363ae 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -1029,6 +1029,50 @@ def select_shot_ids(self, shot_ids=None, start=None, end=None, num=None, every=1 return next_slice + def filter_shot_ids(self, loop, **kwargs): + """ + Filter out shots that have already been used in the optimisation loop. + + Parameters + ---------- + loop : OptimisationLoop + Current optimisation loop. + shot_ids : list, optional + List of shot IDs to select from. + start : int, optional + Start of the slice, defaults to the first id. + end : int, optional + End of the slice, defaults to the last id. + num : int, optional + Number of shots to select every time the method is called. + every : int, optional + How many shots to skip in the selection, defaults to 1, which means taking all shots + subsequently. 
+ randomly : bool, optional + Whether to select the shots at random or in order, defaults to False. + + Returns + ------- + None + The filtered shot IDs are stored in ``self._shot_selection``. + + """ + shot_ids = kwargs.get('shot_ids', None) + shot_starts = self.shot_ids if shot_ids is None else shot_ids + next_slice, selection = _select_slice([], shot_starts, **kwargs) + shot_ids = next_slice + selection + + block = loop.current_block + for iter in block._iterations.values(): + try: + iter_shots = [int(shot_id) for shot_id in iter._runs[0].losses.keys()] + for shot_id in iter_shots: + shot_ids.remove(shot_id) + except KeyError: + pass + + self._shot_selection = shot_ids + def select_sequence_ids(self, start=None, end=None, num=None, every=1, randomly=False): """ Select a number of sequences according to the rules given in the arguments to the method. @@ -1046,10 +1090,10 @@ def select_sequence_ids(self, start=None, end=None, num=None, every=1, randomly= num : int, optional Number of shots to select every time the method is called. every : int, optional - How many shots to skip in the selection, defaults to 1, which means taking all shots + How many sequences to skip in the selection, defaults to 1, which means taking all sequences subsequently. randomly : bool, optional - Whether to select the shots at random at in order, defaults to False. + Whether to select the sequences at random or in order, defaults to False. Returns -------