diff --git a/distarray/context.py b/distarray/context.py index 679278a4..169a1c65 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -68,47 +68,35 @@ def __init__(self, client=None, targets=None): "import distarray.mpiutils; " "import numpy") - self._setup_key_context() - self._make_intracomm() + self.context_key = self._setup_context_key() + self._comm_key = self._make_intracomm() self._set_engine_rank_mapping() - def _set_engine_rank_mapping(self): - # The MPI intracomm referred to by self._comm_key may have a different - # mapping between IPython engines and MPI ranks than COMM_PRIVATE. We - # reorder self.targets so self.targets[i] is the IPython engine ID that - # corresponds to MPI intracomm rank i. - rank = self._generate_key() - self.view.execute( - '%s = %s.Get_rank()' % (rank, self._comm_key), - block=True, targets=self.targets) - - # mapping target -> rank, rank -> target. - rank_from_target = self.view.pull(rank, targets=self.targets).get_dict() - target_from_rank = {v: k for (k, v) in rank_from_target.items()} - - # ensure consistency - assert set(self.targets) == set(rank_from_target) - assert set(range(len(self.targets))) == set(target_from_rank) - - # reorder self.targets so that the targets are in MPI rank order for - # the intracomm. - self.targets = [target_from_rank[i] for i in range(len(target_from_rank))] + def _setup_context_key(self): + """ + Create a dict on the engines which will hold everything from + this context. + """ + context_key = DISTARRAY_BASE_NAME + self.uid() + cmd = '%s = {}' % (context_key) + self._execute(cmd, targets=range(len(self.view))) + return context_key def _make_intracomm(self): def get_rank(): from distarray.mpiutils import COMM_PRIVATE return COMM_PRIVATE.Get_rank() - # get a mapping of IPython engine ID to MPI rank - rank_map = self.view.apply_async(get_rank).get_dict() - ranks = [ rank_map[engine] for engine in self.targets ] - # self.view's engines must encompass all ranks in the MPI communicator, # i.e., everything in rank_map.values(). def get_size(): from distarray.mpiutils import COMM_PRIVATE return COMM_PRIVATE.Get_size() + # get a mapping of IPython engine ID to MPI rank + rank_map = self.view.apply_async(get_rank).get_dict() + ranks = [ rank_map[engine] for engine in self.targets ] + comm_size = self.view.apply_async(get_size).get()[0] if set(rank_map.values()) != set(range(comm_size)): raise ValueError('Engines in view must encompass all MPI ranks.') @@ -116,38 +104,48 @@ def get_size(): # create a new communicator with the subset of engines note that # MPI_Comm_create must be called on all engines, not just those # involved in the new communicator. - self._comm_key = self._generate_key() + comm_key = self._generate_key() + cmd = "%s = distarray.mpiutils.create_comm_with_list(%s)" + cmd %= (comm_key, ranks) + self.view.execute(cmd, block=True) + return comm_key + + def _set_engine_rank_mapping(self): + # The MPI intracomm referred to by self._comm_key may have a different + # mapping between IPython engines and MPI ranks than COMM_PRIVATE. We + # reorder self.targets so self.targets[i] is the IPython engine ID that + # corresponds to MPI intracomm rank i. + rank = self._generate_key() self.view.execute( - '%s = distarray.mpiutils.create_comm_with_list(%s)' % (self._comm_key, ranks), - block=True - ) + '%s = %s.Get_rank()' % (rank, self._comm_key), + block=True, targets=self.targets) - # Key management routines: + # mapping target -> rank, rank -> target. + rank_from_target = self.view.pull(rank, targets=self.targets).get_dict() + target_from_rank = {v: k for (k, v) in rank_from_target.items()} - def _setup_key_context(self): - """ Generate a unique string for this context. + # ensure consistency + assert set(self.targets) == set(rank_from_target) + assert set(range(len(self.targets))) == set(target_from_rank) - This will be included in the names of all keys we create. - This prefix allows us to delete only keys from this context. - """ - # Full length seems excessively verbose so use 16 characters. - uid = uuid.uuid4() - self.key_context = uid.hex[:16] + # reorder self.targets so that the targets are in MPI rank order for + # the intracomm. + self.targets = [target_from_rank[i] for i in range(len(target_from_rank))] @staticmethod - def _key_basename(): + def _key_prefix(): """ Get the base name for all keys. """ return DISTARRAY_BASE_NAME - def _key_prefix(self): - """ Generate a prefix for a key name for this context. """ - header = self._key_basename() + '_' + self.key_context - return header + # Key management routines: + def uid(self): + """Generate a unique valid python name.""" + # Full length seems excessively verbose so use 16 characters. + return 'da' + uuid.uuid4().hex[:16] def _generate_key(self): """ Generate a unique key name for this context. """ - uid = uuid.uuid4() - key = self._key_prefix() + '_' + uid.hex + key = "%s['%s']" % (self.context_key, 'key_' + self.uid()) return key def _key_and_push(self, *values):