diff --git a/distarray/dist/decorators.py b/distarray/dist/decorators.py index 62b7aaa5..98243841 100644 --- a/distarray/dist/decorators.py +++ b/distarray/dist/decorators.py @@ -13,8 +13,8 @@ import functools from distarray.dist.distarray import DistArray -from distarray.dist.context import Context -from distarray.error import ContextError +from distarray.dist.maps import Distribution +from distarray.error import DistributionError from distarray.utils import has_exactly_one @@ -34,29 +34,30 @@ def push_fn(self, context, fn_key, fn): """Push function to the engines.""" context._push({fn_key: fn}, targets=context.targets) - def determine_context(self, args, kwargs): - """ Determine a context from a functions arguments.""" + def determine_distribution(self, args, kwargs): + """ Determine a distribution from a functions arguments.""" - contexts = [] + dists = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) + dists.append(arg.distribution) + elif isinstance(arg, Distribution): + dists.append(arg) # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if dists == []: + raise TypeError('Function must take DistArray or Distribution' + ' objects.') # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) - raise ContextError(msg) + if not dists.count(dists[0]) == len(dists): + msg = ("Arguments must use the same Distribution (given arguments " + "of type %r)") + msg %= (tuple(set(dists)),) + raise DistributionError(msg) - return contexts[0] + return dists[0] def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ @@ -76,7 +77,8 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ if context is None: - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # handle positional arguments arg_keys = [] @@ -114,7 +116,7 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): return arg_str, kwarg_str - def process_return_value(self, context, result_key): + def process_return_value(self, context, targets, result_key): """Figure out what to return on the Client. Parameters @@ -133,7 +135,7 @@ def process_return_value(self, context, result_key): def get_type_str(key): return str(type(key)) result_type_str = context.apply(get_type_str, args=(result_key,), - targets=context.targets) + targets=targets) def is_NoneType(typestring): return (typestring == "" or @@ -144,11 +146,11 @@ def is_LocalArray(typestring): "LocalArray'>") if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context=context) + result = DistArray.from_localarrays(result_key, context=context, targets=targets) elif all(is_NoneType(r) for r in result_type_str): result = None else: - result = context._pull(result_key, targets=context.targets) + result = context._pull(result_key, targets=targets) if has_exactly_one(result): result = next(x for x in result if x is not None) @@ -160,7 +162,8 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) @@ -170,9 +173,9 @@ def __call__(self, *args, **kwargs): exec_str = "%s = %s(*%s, **%s)" exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) - return self.process_return_value(context, result_key) + return self.process_return_value(context, distribution.targets, result_key) class vectorize(DecoratorBase): @@ -186,12 +189,13 @@ def get_ndarray(self, da, arg_keys): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) # vectorize the function exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) # Find the first distarray, they should all be the same up to the data. for arg in args: @@ -208,5 +212,5 @@ def __call__(self, *args, **kwargs): "%s(*%s, **%s)") exec_str %= (out.key, out.key, self.fn_key, args_str, kwargs_str) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) return out diff --git a/distarray/dist/tests/test_decorators.py b/distarray/dist/tests/test_decorators.py index f63107f0..47fa06a9 100644 --- a/distarray/dist/tests/test_decorators.py +++ b/distarray/dist/tests/test_decorators.py @@ -21,29 +21,30 @@ from distarray.dist.context import Context from distarray.dist.maps import Distribution from distarray.dist.decorators import DecoratorBase, local, vectorize -from distarray.error import ContextError +from distarray.error import DistributionError class TestDecoratorBase(TestCase): - def test_determine_context(self): + def test_determine_distribution(self): context = Context() context2 = Context() # for cross Context checking - distribution = Distribution(context, (2, 2)) - da = context.ones(distribution) + dist = Distribution(context, (2, 2)) + dist2 = Distribution(context2, (2, 2)) + da = context.ones(dist) def dummy_func(*args, **kwargs): fn = lambda x: x db = DecoratorBase(fn) - return db.determine_context(args, kwargs) + return db.determine_distribution(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) - self.assertEqual(dummy_func(context, a=da), context) + self.assertEqual(dummy_func(6, 7, dist), dist) + self.assertEqual(dummy_func('ab', da), dist) + self.assertEqual(dummy_func(a=da), dist) + self.assertEqual(dummy_func(dist, a=da), dist) self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + self.assertRaises(DistributionError, dummy_func, dist, dist2) def test_key_and_push_args(self): context = Context() @@ -107,8 +108,7 @@ def local_sum(da): @local def call_barrier(da): - from mpi4py import MPI - MPI.COMM_WORLD.Barrier() + da.comm.Barrier() return da @local @@ -219,15 +219,13 @@ def test_local_add_nums(self): self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): - distribution = Distribution(self.context, (5, 5)) - dg = self.context.empty(distribution) + dg = self.context.empty(self.da.distribution) dg.fill(33) dh = self.local_add_distarrayproxies(self.da, dg) self.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): - distribution = Distribution(self.context, (5, 5)) - di = self.context.empty(distribution) + di = self.context.empty(self.da.distribution) di.fill(33) dj = self.local_add_mixed(self.da, 11, di, 12) self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @@ -245,10 +243,9 @@ def test_local_add_kwargs(self): self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): - distribution = Distribution(self.context, (5, 5)) - dm = self.context.empty(distribution) + dm = self.context.empty(self.da.distribution) dm.fill(22) - dn = self.context.empty(distribution) + dn = self.context.empty(self.da.distribution) dn.fill(44) do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66 diff --git a/distarray/error.py b/distarray/error.py index 3a08176b..27cac7ca 100644 --- a/distarray/error.py +++ b/distarray/error.py @@ -30,3 +30,7 @@ class InvalidRankError(MPIDistArrayError): class MPICommError(MPIDistArrayError): pass + + +class DistributionError(DistArrayError): + pass diff --git a/distarray/tests/test_metadata_utils.py b/distarray/tests/test_metadata_utils.py index 8fad5dc1..83b59330 100644 --- a/distarray/tests/test_metadata_utils.py +++ b/distarray/tests/test_metadata_utils.py @@ -8,6 +8,8 @@ from distarray import metadata_utils from distarray.dist import Distribution, Context +from distarray.testing import ContextTestCase + class TestMakeGridShape(unittest.TestCase): @@ -215,11 +217,7 @@ def test_big_step(self): self.assertEqual(result, expected) -class TestGridSizes(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.context = Context() +class TestGridSizes(ContextTestCase): def test_dist_sizes(self): dist = Distribution(self.context, (2, 3, 4), dist=('n', 'b', 'c'))