Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,6 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
num_iters = kwargs.pop('num_iters', 1)
select_shots = kwargs.pop('select_shots', {})

restart = kwargs.pop('restart', None)
restart_id = kwargs.pop('restart_id', -1)

lazy_loading = kwargs.pop('lazy_loading', False)
dump = kwargs.pop('dump', True)
safe = kwargs.pop('safe', True)
Expand Down Expand Up @@ -267,7 +264,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
if optimiser.reset_block:
optimiser.reset()

for iteration in block.iterations(num_iters, restart=restart, restart_id=restart_id):
for iteration in block.iterations(num_iters):
optimiser.clear_grad()

if optimiser.reset_iteration:
Expand All @@ -281,17 +278,24 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
(iteration.id+1, block.num_iterations, block.id+1,
optimisation_loop.num_blocks))

if dump and block.restart and not optimisation_loop.started:
if dump and optimisation_loop.restart and not optimisation_loop.started:
if iteration.abs_id > 0:
# reload the latest version of the optimiser variable
try:
optimiser.load(path=problem.output_folder,
project_name=problem.name,
version=iteration.abs_id)

logger.perf('\n')
logger.perf('Loaded optimiser variable for restart, version %d' % iteration.abs_id)
except OSError:
raise OSError('Optimisation loop cannot be restarted,'
'variable version or optimiser version %d cannot be found.' %
iteration.abs_id)

# ensure previously used shots are not repeated
problem.acquisitions.filter_shot_ids(optimisation_loop, **select_shots)

shot_ids = problem.acquisitions.select_shot_ids(**select_shots)
num_shots = len(shot_ids)

Expand Down
155 changes: 97 additions & 58 deletions stride/optimisation/optimisation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,19 @@ def next_run(self):
self._runs[self._curr_run_idx] = IterationRun(self._curr_run_idx, self)
return self.curr_run

def clear(self):
"""
Clear iteration.

Returns
-------

"""
self._runs = {
0: IterationRun(0, self),
}
self._curr_run_idx = 0

def clear_run(self):
"""
Clear run memory.
Expand Down Expand Up @@ -292,9 +305,9 @@ def append_iteration(self, *args, **kwargs):
except AttributeError:
return

self_desc = self.__get_desc__(**kwargs)

if h5.file_exists(*args, **dump_kwargs):
self_desc = self.__get_desc__(**kwargs)

description = {
'running_id': loop.running_id,
'num_blocks': loop.num_blocks,
Expand Down Expand Up @@ -340,13 +353,20 @@ def __get_desc__(self, **kwargs):
for run in self._runs.values():
description['runs'][str(run.id)] = run.__get_desc__(**kwargs)

if kwargs.pop('no_runs', False):
del description['runs']

return description

def __set_desc__(self, description, **kwargs):
self.id = description.id
self.abs_id = description.abs_id

runs = description.runs
try:
runs = description.runs
except AttributeError:
self.clear()
return
if isinstance(runs, mosaic.types.Struct):
runs = runs.values()

Expand Down Expand Up @@ -415,7 +435,6 @@ def __init__(self, id, opt_loop, **kwargs):
self._num_iterations = None
self._iterations = dict()
self._current_iteration = None
self.restart = False

@property
def num_iterations(self):
Expand Down Expand Up @@ -477,35 +496,6 @@ def iterations(self, num, *iters, restart=None, restart_id=-1):
Iteration iterables.

"""
loop_restart = self._optimisation_loop.restart
restart = loop_restart if restart is None else restart
self.restart = restart

if restart is False:
self.clear()
else:
if type(restart_id) is int and restart_id < 0:
curr_iter_id = self._current_iteration.id if self._current_iteration is not None else -1
iteration = Iteration(curr_iter_id+1, self._optimisation_loop.running_id,
self, self._optimisation_loop)

self._iterations[curr_iter_id+1] = iteration
self._optimisation_loop.running_id += 1
self._current_iteration = iteration

elif type(restart_id) is int and restart_id >= 0:
if restart_id not in self._iterations:
raise ValueError('Iteration %d does not exist, so loop cannot be '
'restarted from that point' % restart_id)

self._current_iteration = self._iterations[restart_id]

for index in range(restart_id+1, self._num_iterations):
if index in self._iterations:
del self._iterations[index]

if self._current_iteration is not None:
self._optimisation_loop.running_id = self._current_iteration.abs_id+1

if self._num_iterations is None:
self._num_iterations = num
Expand All @@ -525,14 +515,17 @@ def iterations(self, num, *iters, restart=None, restart_id=-1):
self._optimisation_loop.running_id += 1

self._current_iteration = self._iterations[index]
# dump the empty iteration to keep track of loop
self._current_iteration.append_iteration(no_runs=True)

if len(zipped) > 1:
yield (self._iterations[index],) + zipped[1:]
yield (self._current_iteration,) + zipped[1:]
else:
yield self._iterations[index]
yield self._current_iteration

self._optimisation_loop.started = True
self._iterations[index].append_iteration()
# dump completed iteration to enable restart
self._current_iteration.append_iteration()

def __get_desc__(self, **kwargs):
legacy = kwargs.get('legacy', False)
Expand Down Expand Up @@ -679,10 +672,10 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs):
restart : int or bool, optional
Whether or not attempt to restart the loop from a previous
block. Defaults to ``False``.
restart_id : int, optional
restart_id : int or tuple, optional
If an integer greater than zero, it will restart from
a specific block. Otherwise, it will restart from the latest
available block.
a specific iteration. If -1, it will restart from the latest
iteration. If a tuple, it will restart from ``(block_id, iteration_id)``.

Returns
-------
Expand All @@ -703,27 +696,74 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs):

self.load(**load_kwargs)

except (OSError, AttributeError):
self.clear()
self.restart = False

else:
# restart with an absolute iteration_id
if type(restart_id) is int and restart_id >= 0:
if restart_id not in self._blocks:
raise ValueError('Block %d does not exist, so loop cannot be '
if restart_id >= self.running_id:
raise ValueError('Iteration %d does not exist, so loop cannot be '
'restarted from that point' % restart_id)

self._current_block = self._blocks[restart_id]
last_iter = list(self._current_block._iterations.values())[-1]
self.running_id = last_iter.abs_id
block = self._current_block
iter = list(self._current_block._iterations.values())[-1]

if restart_id-1 in self._blocks:
prev_block = self._blocks[restart_id-1]
last_iter = prev_block._iterations[prev_block.num_iterations-1]
self.running_id = last_iter.abs_id
found_iter = False
for block in self._blocks.values():
for iter in block._iterations.values():
if iter.abs_id >= restart_id:
found_iter = True
break
if found_iter:
break

for index in range(restart_id+1, self._num_blocks):
if index in self._blocks:
del self._blocks[index]
self._current_block = block
block._current_iteration = iter
self.running_id = iter.abs_id+1

except (OSError, AttributeError):
self.clear()
self.restart = False
# restart with a tuple (block_id, iteration_id)
elif type(restart_id) is tuple:
block_id, iteration_id = restart_id

# point to requested block
if block_id not in self._blocks:
raise ValueError('Block %d does not exist, so loop cannot be '
'restarted from that point' % block_id)

self._current_block = self._blocks[block_id]
block = self._current_block

# point to requested iteration
if iteration_id < 0:
last_iter = list(block._iterations.values())[-1]
self.running_id = last_iter.abs_id+1

else:
if iteration_id not in block._iterations:
raise ValueError('Iteration %d does not exist, so loop cannot be '
'restarted from that point' % iteration_id)

block._current_iteration = block._iterations[iteration_id]
self.running_id = block._current_iteration.abs_id+1

# access restarted block an iteration
block = self._current_block
iter = self._current_block._current_iteration

# clear iteration
iter.clear()

# delete any other blocks
for index in range(block.id+1, self._num_blocks):
if index in self._blocks:
del self._blocks[index]

# delete any other iterations
for index in range(iter.id+1, block._num_iterations):
if index in block._iterations:
del block._iterations[index]

if self._num_blocks is None:
self._num_blocks = num
Expand All @@ -742,11 +782,9 @@ def blocks(self, num, *iters, restart=False, restart_id=-1, **kwargs):
self._current_block = self._blocks[index]

if len(zipped) > 1:
yield (self._blocks[index],) + zipped[1:]
yield (self._current_block,) + zipped[1:]
else:
yield self._blocks[index]

self.restart = False
yield self._current_block

def dump(self, *args, **kwargs):
"""
Expand All @@ -767,6 +805,7 @@ def dump(self, *args, **kwargs):
dump_kwargs = dict(path=self.problem.output_folder,
project_name=self.problem.name, version=0)
dump_kwargs.update(self._file_kwargs)
dump_kwargs.update(kwargs)

super().dump(*args, **dump_kwargs)
except AttributeError:
Expand Down
48 changes: 46 additions & 2 deletions stride/problem/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,50 @@ def select_shot_ids(self, shot_ids=None, start=None, end=None, num=None, every=1

return next_slice

def filter_shot_ids(self, loop, **kwargs):
"""
Filter out shots that have already been used in the optimisation loop.

Parameters
----------
loop : OptimisationLoop
Current optimisation loop.
shot_ids : list, optional
List of shot IDs to select from.
start : int, optional
Start of the slice, defaults to the first id.
end : int, optional
End of the slice, defaults to the last id.
num : int, optional
Number of shots to select every time the method is called.
every : int, optional
How many shots to skip in the selection, defaults to 1, which means taking all shots
subsequently.
randomly : bool, optional
Whether to select the shots at random at in order, defaults to False.

Returns
-------
list
List with selected shots.

"""
shot_ids = kwargs.get('shot_ids', None)
shot_starts = self.shot_ids if shot_ids is None else shot_ids
next_slice, selection = _select_slice([], shot_starts, **kwargs)
shot_ids = next_slice + selection

block = loop.current_block
for iter in block._iterations.values():
try:
iter_shots = [int(shot_id) for shot_id in iter._runs[0].losses.keys()]
for shot_id in iter_shots:
shot_ids.remove(shot_id)
except KeyError:
pass

self._shot_selection = shot_ids

def select_sequence_ids(self, start=None, end=None, num=None, every=1, randomly=False):
"""
Select a number of sequences according to the rules given in the arguments to the method.
Expand All @@ -1046,10 +1090,10 @@ def select_sequence_ids(self, start=None, end=None, num=None, every=1, randomly=
num : int, optional
Number of shots to select every time the method is called.
every : int, optional
How many shots to skip in the selection, defaults to 1, which means taking all shots
How many sequences to skip in the selection, defaults to 1, which means taking all sequences
subsequently.
randomly : bool, optional
Whether to select the shots at random at in order, defaults to False.
Whether to select the sequences at random at in order, defaults to False.

Returns
-------
Expand Down