diff --git a/distarray/dist/decorators.py b/distarray/dist/decorators.py index 62b7aaa5..98243841 100644 --- a/distarray/dist/decorators.py +++ b/distarray/dist/decorators.py @@ -13,8 +13,8 @@ import functools from distarray.dist.distarray import DistArray -from distarray.dist.context import Context -from distarray.error import ContextError +from distarray.dist.maps import Distribution +from distarray.error import DistributionError from distarray.utils import has_exactly_one @@ -34,29 +34,30 @@ def push_fn(self, context, fn_key, fn): """Push function to the engines.""" context._push({fn_key: fn}, targets=context.targets) - def determine_context(self, args, kwargs): - """ Determine a context from a functions arguments.""" + def determine_distribution(self, args, kwargs): + """ Determine a distribution from a functions arguments.""" - contexts = [] + dists = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) + dists.append(arg.distribution) + elif isinstance(arg, Distribution): + dists.append(arg) # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if dists == []: + raise TypeError('Function must take DistArray or Distribution' + ' objects.') # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) - raise ContextError(msg) + if not dists.count(dists[0]) == len(dists): + msg = ("Arguments must use the same Distribution (given arguments " + "of type %r)") + msg %= (tuple(set(dists)),) + raise DistributionError(msg) - return contexts[0] + return dists[0] def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ @@ -76,7 +77,8 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ if context is None: - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # handle positional arguments arg_keys = [] @@ -114,7 +116,7 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): return arg_str, kwarg_str - def process_return_value(self, context, result_key): + def process_return_value(self, context, targets, result_key): """Figure out what to return on the Client. Parameters @@ -133,7 +135,7 @@ def process_return_value(self, context, result_key): def get_type_str(key): return str(type(key)) result_type_str = context.apply(get_type_str, args=(result_key,), - targets=context.targets) + targets=targets) def is_NoneType(typestring): return (typestring == "" or @@ -144,11 +146,11 @@ def is_LocalArray(typestring): "LocalArray'>") if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context=context) + result = DistArray.from_localarrays(result_key, context=context, targets=targets) elif all(is_NoneType(r) for r in result_type_str): result = None else: - result = context._pull(result_key, targets=context.targets) + result = context._pull(result_key, targets=targets) if has_exactly_one(result): result = next(x for x in result if x is not None) @@ -160,7 +162,8 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) @@ -170,9 +173,9 @@ def __call__(self, *args, **kwargs): exec_str = "%s = %s(*%s, **%s)" exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) - return self.process_return_value(context, result_key) + return self.process_return_value(context, distribution.targets, result_key) class vectorize(DecoratorBase): @@ -186,12 +189,13 @@ def get_ndarray(self, da, arg_keys): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) # vectorize the function exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) # Find the first distarray, they should all be the same up to the data. for arg in args: @@ -208,5 +212,5 @@ def __call__(self, *args, **kwargs): "%s(*%s, **%s)") exec_str %= (out.key, out.key, self.fn_key, args_str, kwargs_str) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) return out diff --git a/distarray/dist/tests/test_decorators.py b/distarray/dist/tests/test_decorators.py index f63107f0..47fa06a9 100644 --- a/distarray/dist/tests/test_decorators.py +++ b/distarray/dist/tests/test_decorators.py @@ -21,29 +21,30 @@ from distarray.dist.context import Context from distarray.dist.maps import Distribution from distarray.dist.decorators import DecoratorBase, local, vectorize -from distarray.error import ContextError +from distarray.error import DistributionError class TestDecoratorBase(TestCase): - def test_determine_context(self): + def test_determine_distribution(self): context = Context() context2 = Context() # for cross Context checking - distribution = Distribution(context, (2, 2)) - da = context.ones(distribution) + dist = Distribution(context, (2, 2)) + dist2 = Distribution(context2, (2, 2)) + da = context.ones(dist) def dummy_func(*args, **kwargs): fn = lambda x: x db = DecoratorBase(fn) - return db.determine_context(args, kwargs) + return db.determine_distribution(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) - self.assertEqual(dummy_func(context, a=da), context) + self.assertEqual(dummy_func(6, 7, dist), dist) + self.assertEqual(dummy_func('ab', da), dist) + self.assertEqual(dummy_func(a=da), dist) + self.assertEqual(dummy_func(dist, a=da), dist) self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + self.assertRaises(DistributionError, dummy_func, dist, dist2) def test_key_and_push_args(self): context = Context() @@ -107,8 +108,7 @@ def local_sum(da): @local def call_barrier(da): - from mpi4py import MPI - MPI.COMM_WORLD.Barrier() + da.comm.Barrier() return da @local @@ -219,15 +219,13 @@ def test_local_add_nums(self): self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): - distribution = Distribution(self.context, (5, 5)) - dg = self.context.empty(distribution) + dg = self.context.empty(self.da.distribution) dg.fill(33) dh = self.local_add_distarrayproxies(self.da, dg) self.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): - distribution = Distribution(self.context, (5, 5)) - di = self.context.empty(distribution) + di = self.context.empty(self.da.distribution) di.fill(33) dj = self.local_add_mixed(self.da, 11, di, 12) self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @@ -245,10 +243,9 @@ def test_local_add_kwargs(self): self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): - distribution = Distribution(self.context, (5, 5)) - dm = self.context.empty(distribution) + dm = self.context.empty(self.da.distribution) dm.fill(22) - dn = self.context.empty(distribution) + dn = self.context.empty(self.da.distribution) dn.fill(44) do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66 diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index b1e10f0e..4e653228 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -722,6 +722,22 @@ def test_create_target_subset(self): self.assertEqual(len(ddpr), len(subtargets)) +class TestReductionRegression(ContextTestCase): + ''' Separate class necessary b/c need to run on at least 9 engines to exercise + regression. + + ''' + + ntargets = 9 + + def test_reduction(self): + """ Tests that GH issue 403 is fixed """ + arr = numpy.arange(9 * 9).reshape(9, 9) + dist = Distribution(self.context, arr.shape, ('b', 'b'), (9, 1)) + darr = self.context.fromndarray(arr, dist) + assert_allclose(darr.sum().tondarray(), arr.sum()) + + class TestReduceMethods(ContextTestCase): """Test reduction methods""" diff --git a/distarray/error.py b/distarray/error.py index 3a08176b..27cac7ca 100644 --- a/distarray/error.py +++ b/distarray/error.py @@ -30,3 +30,7 @@ class InvalidRankError(MPIDistArrayError): class MPICommError(MPIDistArrayError): pass + + +class DistributionError(DistArrayError): + pass diff --git a/distarray/local/localarray.py b/distarray/local/localarray.py index a89d92f3..a2a6505d 100644 --- a/distarray/local/localarray.py +++ b/distarray/local/localarray.py @@ -954,7 +954,8 @@ def _basic_reducer(reduce_comm, op, func, args, kwargs, out): out_ndarray = out.ndarray if out.ndarray.dtype == np.bool: out.ndarray.dtype = np.uint8 - local_reduce = func(*args, **kwargs) + # Use asarray() to coerce np scalars to zero-dimensional arrays. + local_reduce = np.asarray(func(*args, **kwargs)) reduce_comm.Reduce(local_reduce, out_ndarray, op=op, root=0) return out