Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 45 additions & 47 deletions distarray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,86 +68,84 @@ def __init__(self, client=None, targets=None):
"import distarray.mpiutils; "
"import numpy")

self._setup_key_context()
self._make_intracomm()
self.context_key = self._setup_context_key()
self._comm_key = self._make_intracomm()
self._set_engine_rank_mapping()

def _set_engine_rank_mapping(self):
# The MPI intracomm referred to by self._comm_key may have a different
# mapping between IPython engines and MPI ranks than COMM_PRIVATE. We
# reorder self.targets so self.targets[i] is the IPython engine ID that
# corresponds to MPI intracomm rank i.
rank = self._generate_key()
self.view.execute(
'%s = %s.Get_rank()' % (rank, self._comm_key),
block=True, targets=self.targets)

# mapping target -> rank, rank -> target.
rank_from_target = self.view.pull(rank, targets=self.targets).get_dict()
target_from_rank = {v: k for (k, v) in rank_from_target.items()}

# ensure consistency
assert set(self.targets) == set(rank_from_target)
assert set(range(len(self.targets))) == set(target_from_rank)

# reorder self.targets so that the targets are in MPI rank order for
# the intracomm.
self.targets = [target_from_rank[i] for i in range(len(target_from_rank))]
def _setup_context_key(self):
"""
Create a dict on the engines which will hold everything from
this context.
"""
context_key = DISTARRAY_BASE_NAME + self.uid()
cmd = '%s = {}' % (context_key)
self._execute(cmd, targets=range(len(self.view)))
return context_key

def _make_intracomm(self):
def get_rank():
from distarray.mpiutils import COMM_PRIVATE
return COMM_PRIVATE.Get_rank()

# get a mapping of IPython engine ID to MPI rank
rank_map = self.view.apply_async(get_rank).get_dict()
ranks = [ rank_map[engine] for engine in self.targets ]

# self.view's engines must encompass all ranks in the MPI communicator,
# i.e., everything in rank_map.values().
def get_size():
from distarray.mpiutils import COMM_PRIVATE
return COMM_PRIVATE.Get_size()

# get a mapping of IPython engine ID to MPI rank
rank_map = self.view.apply_async(get_rank).get_dict()
ranks = [ rank_map[engine] for engine in self.targets ]

comm_size = self.view.apply_async(get_size).get()[0]
if set(rank_map.values()) != set(range(comm_size)):
raise ValueError('Engines in view must encompass all MPI ranks.')

# create a new communicator with the subset of engines note that
# MPI_Comm_create must be called on all engines, not just those
# involved in the new communicator.
self._comm_key = self._generate_key()
comm_key = self._generate_key()
cmd = "%s = distarray.mpiutils.create_comm_with_list(%s)"
cmd %= (comm_key, ranks)
self.view.execute(cmd, block=True)
return comm_key

def _set_engine_rank_mapping(self):
# The MPI intracomm referred to by self._comm_key may have a different
# mapping between IPython engines and MPI ranks than COMM_PRIVATE. We
# reorder self.targets so self.targets[i] is the IPython engine ID that
# corresponds to MPI intracomm rank i.
rank = self._generate_key()
self.view.execute(
'%s = distarray.mpiutils.create_comm_with_list(%s)' % (self._comm_key, ranks),
block=True
)
'%s = %s.Get_rank()' % (rank, self._comm_key),
block=True, targets=self.targets)

# Key management routines:
# mapping target -> rank, rank -> target.
rank_from_target = self.view.pull(rank, targets=self.targets).get_dict()
target_from_rank = {v: k for (k, v) in rank_from_target.items()}

def _setup_key_context(self):
""" Generate a unique string for this context.
# ensure consistency
assert set(self.targets) == set(rank_from_target)
assert set(range(len(self.targets))) == set(target_from_rank)

This will be included in the names of all keys we create.
This prefix allows us to delete only keys from this context.
"""
# Full length seems excessively verbose so use 16 characters.
uid = uuid.uuid4()
self.key_context = uid.hex[:16]
# reorder self.targets so that the targets are in MPI rank order for
# the intracomm.
self.targets = [target_from_rank[i] for i in range(len(target_from_rank))]

@staticmethod
def _key_basename():
def _key_prefix():
""" Get the base name for all keys. """
return DISTARRAY_BASE_NAME

def _key_prefix(self):
""" Generate a prefix for a key name for this context. """
header = self._key_basename() + '_' + self.key_context
return header
# Key management routines:
def uid(self):
"""Generate a unique valid python name."""
# Full length seems excessively verbose so use 16 characters.
return 'da' + uuid.uuid4().hex[:16]

def _generate_key(self):
""" Generate a unique key name for this context. """
uid = uuid.uuid4()
key = self._key_prefix() + '_' + uid.hex
key = "%s['%s']" % (self.context_key, 'key_' + self.uid())
return key

def _key_and_push(self, *values):
Expand Down