diff --git a/distarray/dist/decorators.py b/distarray/dist/decorators.py index 62b7aaa5..ba763f82 100644 --- a/distarray/dist/decorators.py +++ b/distarray/dist/decorators.py @@ -13,8 +13,8 @@ import functools from distarray.dist.distarray import DistArray -from distarray.dist.context import Context -from distarray.error import ContextError +from distarray.dist.maps import Distribution +from distarray.error import DistributionError from distarray.utils import has_exactly_one @@ -34,29 +34,30 @@ def push_fn(self, context, fn_key, fn): """Push function to the engines.""" context._push({fn_key: fn}, targets=context.targets) - def determine_context(self, args, kwargs): - """ Determine a context from a functions arguments.""" + def determine_distribution(self, args, kwargs): + """ Determine a distribution from a functions arguments.""" - contexts = [] + dists = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) + dists.append(arg.distribution) + elif isinstance(arg, Distribution): + dists.append(arg) # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if dists == []: + raise TypeError('Function must take DistArray or Distribution' + ' objects.') # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) - raise ContextError(msg) + if not dists.count(dists[0]) == len(dists): + msg = ("Arguments must use the same Distribution (given arguments " + "of type %r)") + msg %= (tuple(set(dists)),) + raise DistributionError(msg) - return contexts[0] + return dists[0] def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ @@ -76,7 +77,8 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ if context is None: - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # handle positional arguments arg_keys = [] @@ -160,7 +162,8 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) @@ -170,7 +173,7 @@ def __call__(self, *args, **kwargs): exec_str = "%s = %s(*%s, **%s)" exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) return self.process_return_value(context, result_key) @@ -186,12 +189,13 @@ def get_ndarray(self, da, arg_keys): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) # vectorize the function exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) # Find the first distarray, they should all be the same up to the data. for arg in args: diff --git a/distarray/dist/tests/test_decorators.py b/distarray/dist/tests/test_decorators.py index e6ea707e..0c711976 100644 --- a/distarray/dist/tests/test_decorators.py +++ b/distarray/dist/tests/test_decorators.py @@ -21,29 +21,30 @@ from distarray.dist.context import Context from distarray.dist.maps import Distribution from distarray.dist.decorators import DecoratorBase, local, vectorize -from distarray.error import ContextError +from distarray.error import DistributionError class TestDecoratorBase(TestCase): - def test_determine_context(self): + def test_determine_distribution(self): context = Context() context2 = Context() # for cross Context checking - distribution = Distribution.from_shape(context, (2, 2)) - da = context.ones(distribution) + dist = Distribution.from_shape(context, (2, 2)) + dist2 = Distribution.from_shape(context2, (2, 2)) + da = context.ones(dist) def dummy_func(*args, **kwargs): fn = lambda x: x db = DecoratorBase(fn) - return db.determine_context(args, kwargs) + return db.determine_distribution(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) - self.assertEqual(dummy_func(context, a=da), context) + self.assertEqual(dummy_func(6, 7, dist), dist) + self.assertEqual(dummy_func('ab', da), dist) + self.assertEqual(dummy_func(a=da), dist) + self.assertEqual(dummy_func(dist, a=da), dist) self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + self.assertRaises(DistributionError, dummy_func, dist, dist2) def test_key_and_push_args(self): context = Context() @@ -219,15 +220,13 @@ def test_local_add_nums(self): self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - dg = self.context.empty(distribution) + dg = self.context.empty(self.da.distribution) dg.fill(33) dh = self.local_add_distarrayproxies(self.da, dg) self.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - di = self.context.empty(distribution) + di = self.context.empty(self.da.distribution) di.fill(33) dj = self.local_add_mixed(self.da, 11, di, 12) self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @@ -245,10 +244,9 @@ def test_local_add_kwargs(self): self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - dm = self.context.empty(distribution) + dm = self.context.empty(self.da.distribution) dm.fill(22) - dn = self.context.empty(distribution) + dn = self.context.empty(self.da.distribution) dn.fill(44) do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66 diff --git a/distarray/error.py b/distarray/error.py index 3a08176b..27cac7ca 100644 --- a/distarray/error.py +++ b/distarray/error.py @@ -30,3 +30,7 @@ class InvalidRankError(MPIDistArrayError): class MPICommError(MPIDistArrayError): pass + + +class DistributionError(DistArrayError): + pass