Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions distarray/globalapi/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from distarray.localapi.proxyize import Proxy

# mpi context
from distarray.mpionly_utils import (make_targets_comm, get_nengines,
from distarray.mpionly_utils import (make_targets_comm,
get_world_rank, initial_comm_setup,
is_solo_mpi_process, get_comm_world,
mpi, push_function)
Expand Down Expand Up @@ -710,6 +710,8 @@ def get_size():
def _make_new_comm(rank_list):
import distarray.localapi.mpiutils as mpiutils
new_comm = mpiutils.create_comm_with_list(rank_list)
if not mpiutils.get_base_comm():
mpiutils.set_base_comm(new_comm)
return proxyize(new_comm) # noqa

return self.apply(_make_new_comm, args=(ranks,),
Expand Down Expand Up @@ -859,8 +861,7 @@ def __init__(self, targets=None):
MPIContext.INTERCOMM = initial_comm_setup()
assert get_world_rank() == 0

self.nengines = get_nengines()

self.nengines = MPIContext.INTERCOMM.remote_size
self.all_targets = list(range(self.nengines))
self.targets = self.all_targets if targets is None else sorted(targets)

Expand Down
26 changes: 23 additions & 3 deletions distarray/globalapi/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from distarray.testing import DefaultContextTestCase, IPythonContextTestCase, check_targets
from distarray.globalapi.context import Context
from distarray.globalapi.maps import Distribution
from distarray.mpionly_utils import is_solo_mpi_process, get_nengines
from distarray.mpionly_utils import is_solo_mpi_process
from distarray.localapi import LocalArray


Expand Down Expand Up @@ -263,12 +263,14 @@ def test_create_context(self):

def test_create_Context_with_targets(self):
"""Can we create a context with a subset of engines?"""
check_targets(required=2, available=get_nengines())
from distarray.globalapi.context import MPIContext
check_targets(required=2, available=MPIContext.INTERCOMM.remote_size)
Context(targets=[0, 1])

def test_create_Context_with_targets_ranks(self):
"""Check that the target <=> rank mapping is consistent."""
check_targets(required=4, available=get_nengines())
from distarray.globalapi.context import MPIContext
check_targets(required=4, available=MPIContext.INTERCOMM.remote_size)
targets = [3, 2]
dac = Context(targets=targets)
self.assertEqual(set(dac.targets), set(targets))
Expand Down Expand Up @@ -399,6 +401,24 @@ def foo():
self.assertEqual(set(r[0].name for r in res), set([res[0][0].name]))
self.assertEqual(set(r[-1].name for r in res), set([res[0][-1].name]))

class TestGetBaseComm(DefaultContextTestCase):

ntargets = 'any'

def test_get_base_comm(self):

def local_test_type():
from mpi4py.MPI import Intracomm
from distarray.localapi.mpiutils import get_base_comm
bc = get_base_comm()
return isinstance(bc, Intracomm), bc.rank, bc.size

test, ranks, sizes = zip(*self.context.apply(local_test_type, ()))

self.assertTrue(all(test))
self.assertSetEqual(set(ranks), set(range(len(ranks))))
self.assertSetEqual(set(sizes), set([len(sizes)]))


if __name__ == '__main__':
unittest.main(verbosity=2)
1 change: 1 addition & 0 deletions distarray/localapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@

from distarray.localapi import localarray
from distarray.localapi.localarray import *
from distarray.localapi.mpiutils import get_base_comm
7 changes: 7 additions & 0 deletions distarray/localapi/mpiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
from mpi4py import MPI
from distarray.error import InvalidCommSizeError, InvalidRankError

def get_base_comm():
return _BASE_COMM

_BASE_COMM = None
def set_base_comm(comm):
global _BASE_COMM
_BASE_COMM = comm

def get_comm_private():
return MPI.COMM_WORLD.Clone()
Expand Down
60 changes: 2 additions & 58 deletions distarray/mpionly_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,56 +60,13 @@ def reassemble_and_store_func(key_dummy_container, func_data):
targets=context.targets)


def get_nengines():
"""Get the number of engines which must be COMM_WORLD.size - 1 (for the
client)
"""
return get_comm_world().size - 1


def _set_on_main(name, obj):
"""Add obj as an attribute to the __main__ module with alias `name` like:
__main__.name = obj
"""
return Proxy(name, obj, '__main__')


def make_intercomm(targets=None):
world = get_comm_world()
world_rank = world.rank
# create a comm that is split into client and engines.
targets = targets or list(range(world.size - 1))

if world_rank == client_rank:
split_world = world.Split(0, 0)
else:
split_world = world.Split(1, world_rank)

# create the intercomm
if world_rank == client_rank:
intercomm = split_world.Create_intercomm(0, world, 1)
else:
intercomm = split_world.Create_intercomm(0, world, 0)
return intercomm


def make_base_comm():
"""
Creates an intracomm consisting of all the engines. Then sets:
`__main__._base_comm = comm_name`
"""
world = get_comm_world()
if world.rank == 0:
comm_name = uid()
else:
comm_name = ''
comm_name = world.bcast(comm_name)

engines = world.group.Excl([client_rank])
engine_comm = world.Create(engines)
return _set_on_main(comm_name, engine_comm)


def make_targets_comm(targets):
world = get_comm_world()
world_rank = world.rank
Expand Down Expand Up @@ -140,21 +97,6 @@ def make_targets_comm(targets):
return _set_on_main(comm_name, targets_comm)


def setup_engine_comm(targets=None):
# create a comm that is split into client and engines.
world = get_comm_world()
world_rank = world.rank
targets = range(world.size - 1) if targets is None else targets
name = uid()
if world_rank == client_rank:
split_world = world.Split(0, 0)
elif (world_rank + 1) in targets:
split_world = world.Split(1, world_rank)
_set_on_main(name, split_world)
else:
world.Split(2, world_rank)


def initial_comm_setup():
"""Setup client and engine intracomm, and intercomm."""
world = get_comm_world()
Expand All @@ -165,6 +107,8 @@ def initial_comm_setup():
split_world = world.Split(0, 0)
else:
split_world = world.Split(1, world_rank)
from distarray.localapi.mpiutils import set_base_comm
set_base_comm(split_world)

# create the intercomm
if world_rank == client_rank:
Expand Down