diff --git a/distarray/globalapi/context.py b/distarray/globalapi/context.py index 2809e142..ae4a2da6 100644 --- a/distarray/globalapi/context.py +++ b/distarray/globalapi/context.py @@ -30,7 +30,7 @@ from distarray.localapi.proxyize import Proxy # mpi context -from distarray.mpionly_utils import (make_targets_comm, get_nengines, +from distarray.mpionly_utils import (make_targets_comm, get_world_rank, initial_comm_setup, is_solo_mpi_process, get_comm_world, mpi, push_function) @@ -710,6 +710,8 @@ def get_size(): def _make_new_comm(rank_list): import distarray.localapi.mpiutils as mpiutils new_comm = mpiutils.create_comm_with_list(rank_list) + if not mpiutils.get_base_comm(): + mpiutils.set_base_comm(new_comm) return proxyize(new_comm) # noqa return self.apply(_make_new_comm, args=(ranks,), @@ -859,8 +861,7 @@ def __init__(self, targets=None): MPIContext.INTERCOMM = initial_comm_setup() assert get_world_rank() == 0 - self.nengines = get_nengines() - + self.nengines = MPIContext.INTERCOMM.remote_size self.all_targets = list(range(self.nengines)) self.targets = self.all_targets if targets is None else sorted(targets) diff --git a/distarray/globalapi/tests/test_context.py b/distarray/globalapi/tests/test_context.py index 09848eeb..bac00754 100644 --- a/distarray/globalapi/tests/test_context.py +++ b/distarray/globalapi/tests/test_context.py @@ -22,7 +22,7 @@ from distarray.testing import DefaultContextTestCase, IPythonContextTestCase, check_targets from distarray.globalapi.context import Context from distarray.globalapi.maps import Distribution -from distarray.mpionly_utils import is_solo_mpi_process, get_nengines +from distarray.mpionly_utils import is_solo_mpi_process from distarray.localapi import LocalArray @@ -263,12 +263,14 @@ def test_create_context(self): def test_create_Context_with_targets(self): """Can we create a context with a subset of engines?""" - check_targets(required=2, available=get_nengines()) + from distarray.globalapi.context import MPIContext + check_targets(required=2, available=MPIContext.INTERCOMM.remote_size) Context(targets=[0, 1]) def test_create_Context_with_targets_ranks(self): """Check that the target <=> rank mapping is consistent.""" - check_targets(required=4, available=get_nengines()) + from distarray.globalapi.context import MPIContext + check_targets(required=4, available=MPIContext.INTERCOMM.remote_size) targets = [3, 2] dac = Context(targets=targets) self.assertEqual(set(dac.targets), set(targets)) @@ -399,6 +401,24 @@ def foo(): self.assertEqual(set(r[0].name for r in res), set([res[0][0].name])) self.assertEqual(set(r[-1].name for r in res), set([res[0][-1].name])) +class TestGetBaseComm(DefaultContextTestCase): + + ntargets = 'any' + + def test_get_base_comm(self): + + def local_test_type(): + from mpi4py.MPI import Intracomm + from distarray.localapi.mpiutils import get_base_comm + bc = get_base_comm() + return isinstance(bc, Intracomm), bc.rank, bc.size + + test, ranks, sizes = zip(*self.context.apply(local_test_type, ())) + + self.assertTrue(all(test)) + self.assertSetEqual(set(ranks), set(range(len(ranks)))) + self.assertSetEqual(set(sizes), set([len(sizes)])) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/distarray/localapi/__init__.py b/distarray/localapi/__init__.py index 0f4f80d2..1570895b 100644 --- a/distarray/localapi/__init__.py +++ b/distarray/localapi/__init__.py @@ -12,3 +12,4 @@ from distarray.localapi import localarray from distarray.localapi.localarray import * +from distarray.localapi.mpiutils import get_base_comm diff --git a/distarray/localapi/mpiutils.py b/distarray/localapi/mpiutils.py index 6188710b..006cb79b 100644 --- a/distarray/localapi/mpiutils.py +++ b/distarray/localapi/mpiutils.py @@ -12,6 +12,13 @@ from mpi4py import MPI from distarray.error import InvalidCommSizeError, InvalidRankError +def get_base_comm(): + return _BASE_COMM + +_BASE_COMM = None +def set_base_comm(comm): + global _BASE_COMM + _BASE_COMM = comm def get_comm_private(): return MPI.COMM_WORLD.Clone() diff --git a/distarray/mpionly_utils.py b/distarray/mpionly_utils.py index 5e820e6a..5d109064 100644 --- a/distarray/mpionly_utils.py +++ b/distarray/mpionly_utils.py @@ -60,13 +60,6 @@ def reassemble_and_store_func(key_dummy_container, func_data): targets=context.targets) -def get_nengines(): - """Get the number of engines which must be COMM_WORLD.size - 1 (for the - client) - """ - return get_comm_world().size - 1 - - def _set_on_main(name, obj): """Add obj as an attribute to the __main__ module with alias `name` like: __main__.name = obj @@ -74,42 +67,6 @@ def _set_on_main(name, obj): return Proxy(name, obj, '__main__') -def make_intercomm(targets=None): - world = get_comm_world() - world_rank = world.rank - # create a comm that is split into client and engines. - targets = targets or list(range(world.size - 1)) - - if world_rank == client_rank: - split_world = world.Split(0, 0) - else: - split_world = world.Split(1, world_rank) - - # create the intercomm - if world_rank == client_rank: - intercomm = split_world.Create_intercomm(0, world, 1) - else: - intercomm = split_world.Create_intercomm(0, world, 0) - return intercomm - - -def make_base_comm(): - """ - Creates an intracomm consisting of all the engines. Then sets: - `__main__._base_comm = comm_name` - """ - world = get_comm_world() - if world.rank == 0: - comm_name = uid() - else: - comm_name = '' - comm_name = world.bcast(comm_name) - - engines = world.group.Excl([client_rank]) - engine_comm = world.Create(engines) - return _set_on_main(comm_name, engine_comm) - - def make_targets_comm(targets): world = get_comm_world() world_rank = world.rank @@ -140,21 +97,6 @@ def make_targets_comm(targets): return _set_on_main(comm_name, targets_comm) -def setup_engine_comm(targets=None): - # create a comm that is split into client and engines. - world = get_comm_world() - world_rank = world.rank - targets = range(world.size - 1) if targets is None else targets - name = uid() - if world_rank == client_rank: - split_world = world.Split(0, 0) - elif (world_rank + 1) in targets: - split_world = world.Split(1, world_rank) - _set_on_main(name, split_world) - else: - world.Split(2, world_rank) - - def initial_comm_setup(): """Setup client and engine intracomm, and intercomm.""" world = get_comm_world() @@ -165,6 +107,8 @@ def initial_comm_setup(): split_world = world.Split(0, 0) else: split_world = world.Split(1, world_rank) + from distarray.localapi.mpiutils import set_base_comm + set_base_comm(split_world) # create the intercomm if world_rank == client_rank: