From 29fbaab71157d33bc90afb8e2fe54c629e80f96c Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 16 Jun 2014 10:13:35 -0500 Subject: [PATCH 1/3] Use distribution.targets instead of context.targets in decorators.py --- distarray/dist/decorators.py | 39 +++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/distarray/dist/decorators.py b/distarray/dist/decorators.py index 62b7aaa5..42126747 100644 --- a/distarray/dist/decorators.py +++ b/distarray/dist/decorators.py @@ -13,7 +13,7 @@ import functools from distarray.dist.distarray import DistArray -from distarray.dist.context import Context +from distarray.dist.maps import Distribution from distarray.error import ContextError from distarray.utils import has_exactly_one @@ -34,29 +34,30 @@ def push_fn(self, context, fn_key, fn): """Push function to the engines.""" context._push({fn_key: fn}, targets=context.targets) - def determine_context(self, args, kwargs): - """ Determine a context from a functions arguments.""" + def determine_distribution(self, args, kwargs): + """ Determine a distribution from a functions arguments.""" - contexts = [] + dists = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) + dists.append(arg.distribution) + elif isinstance(arg, Distribution): + dists.append(arg) # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if dists == []: + raise TypeError('Function must take DistArray or Distribution' + ' objects.') # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) + if not dists.count(dists[0]) == len(dists): + msg = ("Arguments must use the same Distribution (given arguments " + "of type %r)") + msg %= (tuple(set(dists)),) raise ContextError(msg) - return contexts[0] + return dists[0] def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ @@ -160,7 +161,8 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) @@ -170,7 +172,7 @@ def __call__(self, *args, **kwargs): exec_str = "%s = %s(*%s, **%s)" exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) return self.process_return_value(context, result_key) @@ -186,12 +188,13 @@ def get_ndarray(self, da, arg_keys): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # push function self.push_fn(context, self.fn_key, self.fn) # vectorize the function exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) - context._execute(exec_str, targets=context.targets) + context._execute(exec_str, targets=distribution.targets) # Find the first distarray, they should all be the same up to the data. for arg in args: From c88eb583bf770a9ebe1a50ba2af5d8b4808c23f6 Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 16 Jun 2014 10:50:27 -0500 Subject: [PATCH 2/3] Change ContextError to DistributionError in decorators. --- distarray/dist/decorators.py | 7 ++++--- distarray/error.py | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/distarray/dist/decorators.py b/distarray/dist/decorators.py index 42126747..ba763f82 100644 --- a/distarray/dist/decorators.py +++ b/distarray/dist/decorators.py @@ -14,7 +14,7 @@ from distarray.dist.distarray import DistArray from distarray.dist.maps import Distribution -from distarray.error import ContextError +from distarray.error import DistributionError from distarray.utils import has_exactly_one @@ -55,7 +55,7 @@ def determine_distribution(self, args, kwargs): msg = ("Arguments must use the same Distribution (given arguments " "of type %r)") msg %= (tuple(set(dists)),) - raise ContextError(msg) + raise DistributionError(msg) return dists[0] @@ -77,7 +77,8 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): """ if context is None: - context = self.determine_context(args, kwargs) + distribution = self.determine_distribution(args, kwargs) + context = distribution.context # handle positional arguments arg_keys = [] diff --git a/distarray/error.py b/distarray/error.py index 3a08176b..27cac7ca 100644 --- a/distarray/error.py +++ b/distarray/error.py @@ -30,3 +30,7 @@ class InvalidRankError(MPIDistArrayError): class MPICommError(MPIDistArrayError): pass + + +class DistributionError(DistArrayError): + pass From fc0aad963704ac2470c83d21a07394528da8208a Mon Sep 17 00:00:00 2001 From: Blake Griffith Date: Mon, 16 Jun 2014 10:51:08 -0500 Subject: [PATCH 3/3] Reflect changes in decorator tests --- distarray/dist/tests/test_decorators.py | 32 ++++++++++++------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/distarray/dist/tests/test_decorators.py b/distarray/dist/tests/test_decorators.py index e6ea707e..0c711976 100644 --- a/distarray/dist/tests/test_decorators.py +++ b/distarray/dist/tests/test_decorators.py @@ -21,29 +21,30 @@ from distarray.dist.context import Context from distarray.dist.maps import Distribution from distarray.dist.decorators import DecoratorBase, local, vectorize -from distarray.error import ContextError +from distarray.error import DistributionError class TestDecoratorBase(TestCase): - def test_determine_context(self): + def test_determine_distribution(self): context = Context() context2 = Context() # for cross Context checking - distribution = Distribution.from_shape(context, (2, 2)) - da = context.ones(distribution) + dist = Distribution.from_shape(context, (2, 2)) + dist2 = Distribution.from_shape(context2, (2, 2)) + da = context.ones(dist) def dummy_func(*args, **kwargs): fn = lambda x: x db = DecoratorBase(fn) - return db.determine_context(args, kwargs) + return db.determine_distribution(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) - self.assertEqual(dummy_func(context, a=da), context) + self.assertEqual(dummy_func(6, 7, dist), dist) + self.assertEqual(dummy_func('ab', da), dist) + self.assertEqual(dummy_func(a=da), dist) + self.assertEqual(dummy_func(dist, a=da), dist) self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + self.assertRaises(DistributionError, dummy_func, dist, dist2) def test_key_and_push_args(self): context = Context() @@ -219,15 +220,13 @@ def test_local_add_nums(self): self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - dg = self.context.empty(distribution) + dg = self.context.empty(self.da.distribution) dg.fill(33) dh = self.local_add_distarrayproxies(self.da, dg) self.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - di = self.context.empty(distribution) + di = self.context.empty(self.da.distribution) di.fill(33) dj = self.local_add_mixed(self.da, 11, di, 12) self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @@ -245,10 +244,9 @@ def test_local_add_kwargs(self): self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): - distribution = Distribution.from_shape(self.context, (5, 5)) - dm = self.context.empty(distribution) + dm = self.context.empty(self.da.distribution) dm.fill(22) - dn = self.context.empty(distribution) + dn = self.context.empty(self.da.distribution) dn.fill(44) do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66