From c2dc4a68ffe6a354495e7a916addef8c71d3950f Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Thu, 17 Apr 2014 17:15:30 -0500 Subject: [PATCH 1/7] WIP: decorators -> register with context. --- distarray/context.py | 5 ++ distarray/decorators.py | 147 +++++++++++++--------------------- distarray/local/localarray.py | 6 ++ 3 files changed, 68 insertions(+), 90 deletions(-) diff --git a/distarray/context.py b/distarray/context.py index 41d10148..76d766a6 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -109,6 +109,11 @@ def get_size(): block=True ) + # Function registration. + + def register(self, func): + setattr(self, func.__name__, func) + # Key management routines: def _setup_key_context(self): diff --git a/distarray/decorators.py b/distarray/decorators.py index c302b92e..f9f60271 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -11,10 +11,29 @@ import functools from distarray.client import DistArray -from distarray.context import Context +from distarray.context import Context, DISTARRAY_BASE_NAME from distarray.error import ContextError from distarray.utils import has_exactly_one +def _rpc(func, args, kwargs, result_key): + + args = list(args) + for idx, a in enumerate(args): + if isinstance(a, basestring): + if a.startswith(DISTARRAY_BASE_NAME): + args[idx] = eval(a) + + for k, v in kwargs.items(): + if isinstance(v, basestring): + if v.startswith(DISTARRAY_BASE_NAME): + kwargs[k] = eval(v) + + res = func(*args, **kwargs) + if isinstance(res, LocalArray): + globals()[result_key] = res + res = result_key + return res + class DecoratorBase(object): """ @@ -22,41 +41,30 @@ class DecoratorBase(object): decorator to take an optional kwarg. """ - def __init__(self, fn): + def __init__(self, fn, context): self.fn = fn self.fn_key = self.fn.__name__ functools.update_wrapper(self, fn) - self.context = None + self.context = context + self.push_fn() - def push_fn(self, context, fn_key, fn): + def push_fn(self): """Push function to the engines.""" - context._push({fn_key: fn}) + self.context._push({self.fn_key: self.fn}) - def determine_context(self, args, kwargs): + def check_contexts(self, args, kwargs): """ Determine a context from a functions arguments.""" - contexts = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) - - # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if arg.context != self.context: + msg = "DistArray %r not in same context as registered function %r." + raise ContextError(msg % (arg, self.fn)) - # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) - raise ContextError(msg) + return self.context - return contexts[0] - - def key_and_push_args(self, args, kwargs, context=None, da_handler=None): + def build_args(self, args, kwargs): """ Push a tuple of args and dict of kwargs to the engines. Return a tuple with keys corresponding to args values on the engines. And a @@ -73,46 +81,19 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None): >>> context.execute(exec_str) """ - if context is None: - context = self.determine_context(args, kwargs) - - # handle positional arguments - arg_keys = [] - push_keys = {} - for arg in args: + args = list(args) + for idx, arg in enumerate(args): if isinstance(arg, DistArray): - if da_handler is None: - arg_keys.append(arg.key) - # da_handler handles distarrays. - else: - arg_keys = da_handler(arg, arg_keys) - else: - new_key = context._generate_key() - arg_keys.append(new_key) - push_keys[new_key] = arg + args[idx] = arg.key # handle key word arguments - for kw in kwargs: - if isinstance(kwargs[kw], DistArray): - kwargs[kw] = kwargs[kw].key - else: - new_key = context._generate_key() - push_keys[new_key] = kwargs[kw] - kwargs[kw] = new_key - - # push the keys to the engines - context._push(push_keys) - - # build arg string - arg_str = '(' + ', '.join(arg_keys) + ',)' + for k, v in kwargs.items(): + if isinstance(v, DistArray): + kwargs[k] = v.key - # build kwarg string - kwarg_iter = ["'%s': %s" % (k, v) for (k, v) in kwargs.items()] - kwarg_str = '{' + ', '.join(kwarg_iter) + '}' + return args, kwargs - return arg_str, kwarg_str - - def process_return_value(self, context, result_key): + def process_return_value(self, result_from_target): """Figure out what to return on the Client. Parameters @@ -128,28 +109,22 @@ def process_return_value(self, context, result_key): client and return it. If all but one of the pulled values is None, return that non-None value only. """ - type_key = context._generate_key() - type_statement = "{} = str(type({}))".format(type_key, result_key) - context._execute(type_statement) - result_type_str = context._pull(type_key) - - def is_NoneType(typestring): - return (typestring == "" or - typestring == "") - - def is_LocalArray(typestring): - return (typestring == "") - - if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context) - elif all(is_NoneType(r) for r in result_type_str): + # type_key = self.context._generate_key() + # type_statement = "{} = str(type({}))".format(type_key, result_key) + # context._execute(type_statement) + # result_type_str = context._pull(type_key) + + results = list(result_from_target.values()) + + if all(isinstance(r, basestring) and r.startswith(DISTARRAY_BASE_NAME) + for r in results): + result = DistArray.from_localarrays(results[0], self.context) + elif all(r is None for r in results): result = None else: - result = context._pull(result_key) - if has_exactly_one(result): - result = next(x for x in result if x is not None) - + non_nones = [r for r in results if r is not None] + if len(non_nones) == 1: + result = non_nones[0] return result @@ -159,21 +134,13 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args context = self.determine_context(args, kwargs) - # push function - self.push_fn(context, self.fn_key, self.fn) - - args, kwargs = self.key_and_push_args(args, kwargs, - context=context) + args, kwargs = self.build_args(args, kwargs) result_key = context._generate_key() - - exec_str = "%s = %s(*%s, **%s)" - exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str) - - return self.process_return_value(context, result_key) + results = context.view.apply_async(_rpc, self.func, args, kwargs, result_key).get_dict() + return self.process_return_value(results) -class vectorize(DecoratorBase): +class _vectorize(DecoratorBase): """ Analogous to numpy.vectorize. Input DistArray's must all be the same shape, and this will be the shape of the output distarray. diff --git a/distarray/local/localarray.py b/distarray/local/localarray.py index 8b5d2563..657af035 100644 --- a/distarray/local/localarray.py +++ b/distarray/local/localarray.py @@ -29,6 +29,12 @@ from distarray.local.error import InvalidDimensionError, IncompatibleArrayError +def _rpc(payload): + funcname, args, kwargs = payload + func = globals()[funcname] + return func(*args, **kwargs) + + def _start_stop_block(size, proc_grid_size, proc_grid_rank): nelements = size // proc_grid_size if size % proc_grid_size != 0: From 213ece9e4c1ce23cf4bc11fe666a89a9c744f90a Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Wed, 23 Apr 2014 12:44:39 -0500 Subject: [PATCH 2/7] Testsuite working with refactored decorators. --- distarray/context.py | 8 +- distarray/decorators.py | 50 +++-- distarray/plotting/plotting.py | 33 ++-- distarray/tests/test_decorators.py | 297 +++++++++++++++-------------- 4 files changed, 200 insertions(+), 188 deletions(-) diff --git a/distarray/context.py b/distarray/context.py index 76d766a6..213fc85c 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -19,6 +19,8 @@ from distarray.client_map import ClientMDMap from distarray.ipython_utils import IPythonClient +DISTARRAY_BASE_NAME = '__distarray__' + class Context(object): ''' @@ -112,7 +114,9 @@ def get_size(): # Function registration. def register(self, func): - setattr(self, func.__name__, func) + from distarray.decorators import local + lf = local(func, self) + setattr(self, func.__name__, lf) # Key management routines: @@ -128,7 +132,7 @@ def _setup_key_context(self): def _key_basename(self): """ Get the base name for all keys. """ - return '_distarray_key' + return DISTARRAY_BASE_NAME def _key_prefix(self): """ Generate a prefix for a key name for this context. """ diff --git a/distarray/decorators.py b/distarray/decorators.py index f9f60271..b609681a 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -15,23 +15,29 @@ from distarray.error import ContextError from distarray.utils import has_exactly_one -def _rpc(func, args, kwargs, result_key): +def _rpc(func, args, kwargs, result_key, prefix): + + main = __import__('__main__') + + from distarray.local.localarray import LocalArray args = list(args) for idx, a in enumerate(args): if isinstance(a, basestring): - if a.startswith(DISTARRAY_BASE_NAME): - args[idx] = eval(a) + if a.startswith(prefix): + args[idx] = getattr(main, a) for k, v in kwargs.items(): if isinstance(v, basestring): - if v.startswith(DISTARRAY_BASE_NAME): - kwargs[k] = eval(v) + if v.startswith(prefix): + kwargs[k] = getattr(main, v) + + print args, kwargs res = func(*args, **kwargs) if isinstance(res, LocalArray): - globals()[result_key] = res - res = result_key + setattr(main, result_key, res) + return result_key return res @@ -41,16 +47,16 @@ class DecoratorBase(object): decorator to take an optional kwarg. """ - def __init__(self, fn, context): - self.fn = fn - self.fn_key = self.fn.__name__ - functools.update_wrapper(self, fn) + def __init__(self, func, context): + self.func = func + self.func_key = self.func.__name__ + functools.update_wrapper(self, func) self.context = context - self.push_fn() + self.push_func() - def push_fn(self): + def push_func(self): """Push function to the engines.""" - self.context._push({self.fn_key: self.fn}) + self.context._push({self.func_key: self.func}) def check_contexts(self, args, kwargs): """ Determine a context from a functions arguments.""" @@ -60,7 +66,7 @@ def check_contexts(self, args, kwargs): if isinstance(arg, DistArray): if arg.context != self.context: msg = "DistArray %r not in same context as registered function %r." - raise ContextError(msg % (arg, self.fn)) + raise ContextError(msg % (arg, self.func)) return self.context @@ -125,6 +131,8 @@ def process_return_value(self, result_from_target): non_nones = [r for r in results if r is not None] if len(non_nones) == 1: result = non_nones[0] + else: + result = results return result @@ -133,10 +141,10 @@ class local(DecoratorBase): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + context = self.check_contexts(args, kwargs) args, kwargs = self.build_args(args, kwargs) result_key = context._generate_key() - results = context.view.apply_async(_rpc, self.func, args, kwargs, result_key).get_dict() + results = context.view.apply_async(_rpc, self.func, args, kwargs, result_key, DISTARRAY_BASE_NAME).get_dict() return self.process_return_value(results) @@ -151,11 +159,11 @@ def get_local_array(self, da, arg_keys): def __call__(self, *args, **kwargs): # get context from args - context = self.determine_context(args, kwargs) + context = self.check_contexts(args, kwargs) # push function - self.push_fn(context, self.fn_key, self.fn) + self.push_func(context, self.func_key, self.func) # vectorize the function - exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) + exec_str = "%s = numpy.vectorize(%s)" % (self.func_key, self.func_key) context._execute(exec_str) # Find the first distarray, they should all be the same up to the data. @@ -173,7 +181,7 @@ def __call__(self, *args, **kwargs): # Call the function exec_str = ("if %s.local_array.size != 0: %s.local_array = " "%s(*%s, **%s)") - exec_str %= (out.key, out.key, self.fn_key, args_str, + exec_str %= (out.key, out.key, self.func_key, args_str, kwargs_str) context._execute(exec_str) return out diff --git a/distarray/plotting/plotting.py b/distarray/plotting/plotting.py index 1b89e399..368e7309 100644 --- a/distarray/plotting/plotting.py +++ b/distarray/plotting/plotting.py @@ -10,22 +10,7 @@ from six.moves import range from matplotlib import pyplot, colors, cm -from numpy import arange, concatenate, empty, linspace, resize - -from distarray.decorators import local - - -@local -def _get_ranks(arr): - """ - Given a distarray arr, return a distarray with the same shape, but - with the elements equal to the rank of the process the element is - on. - """ - out = arr.copy() - out.local_array[:] = arr.comm_rank - out.local_array = out.local_array.astype(int) - return out +from numpy import concatenate, empty, linspace def cmap_discretize(cmap, N): @@ -272,8 +257,22 @@ def plot_array_distribution(darray, # This is based somewhat on: # http://matplotlib.org/examples/api/colorbar_only.html + def _get_ranks(arr): + """ + Given a distarray arr, return a distarray with the same shape, but + with the elements equal to the rank of the process the element is + on. + """ + out = arr.copy() + out.local_array[:] = arr.comm_rank + out.local_array = out.local_array.astype(int) + return out + + ctx = darray.context + ctx.register(_get_ranks) + # Process per element. - process_darray = _get_ranks(darray) + process_darray = ctx._get_ranks(darray) process_array = process_darray.toarray() # Values per element. diff --git a/distarray/tests/test_decorators.py b/distarray/tests/test_decorators.py index b0110898..c9e9b848 100644 --- a/distarray/tests/test_decorators.py +++ b/distarray/tests/test_decorators.py @@ -18,131 +18,131 @@ from numpy.testing import assert_array_equal from distarray.context import Context -from distarray.decorators import DecoratorBase, local, vectorize +from distarray.decorators import DecoratorBase, local from distarray.error import ContextError -class TestDecoratorBase(TestCase): +# class TestDecoratorBase(TestCase): - def test_determine_context(self): - context = Context() - context2 = Context() # for cross Context checking - da = context.ones((2, 2)) + # def test_determine_context(self): + # context = Context() + # context2 = Context() # for cross Context checking + # da = context.ones((2, 2)) - def dummy_func(*args, **kwargs): - fn = lambda x: x - db = DecoratorBase(fn) - return db.determine_context(args, kwargs) + # def dummy_func(*args, **kwargs): + # fn = lambda x: x + # db = DecoratorBase(fn) + # return db.determine_context(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) - self.assertEqual(dummy_func(context, a=da), context) + # self.assertEqual(dummy_func(6, 7, context), context) + # self.assertEqual(dummy_func('ab', da), context) + # self.assertEqual(dummy_func(a=da), context) + # self.assertEqual(dummy_func(context, a=da), context) - self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + # self.assertRaises(TypeError, dummy_func, 'foo') + # self.assertRaises(ContextError, dummy_func, context, context2) - def test_key_and_push_args(self): - context = Context() + # def test_key_and_push_args(self): + # context = Context() - da = context.ones((2, 2)) - db = da*2 + # da = context.ones((2, 2)) + # db = da*2 - def dummy_func(*args, **kwargs): - fn = lambda x: x - db = DecoratorBase(fn) - return db.key_and_push_args(args, kwargs) + # def dummy_func(*args, **kwargs): + # fn = lambda x: x + # db = DecoratorBase(fn) + # return db.key_and_push_args(args, kwargs) - # Push some distarrays - arg_keys1, kw_keys1 = dummy_func(da, db, foo=da, bar=db) - # with some other data too - arg_keys2, kw_keys2 = dummy_func(da, 'question', answer=42, foo=db) + # # Push some distarrays + # arg_keys1, kw_keys1 = dummy_func(da, db, foo=da, bar=db) + # # with some other data too + # arg_keys2, kw_keys2 = dummy_func(da, 'question', answer=42, foo=db) - self.assertEqual(arg_keys1, "(%s, %s,)" % (da.key, db.key)) - # assert we pushed the right key, keystr pair - self.assertTrue("'foo': %s" % (da.key) in kw_keys1) - self.assertTrue("'bar': %s" % (db.key) in kw_keys1) + # self.assertEqual(arg_keys1, "(%s, %s,)" % (da.key, db.key)) + # # assert we pushed the right key, keystr pair + # self.assertTrue("'foo': %s" % (da.key) in kw_keys1) + # self.assertTrue("'bar': %s" % (db.key) in kw_keys1) - # lots of string manipulation to parse out the relevant pieces - # of the python commands. - self.assertEqual(arg_keys2[1: -2].split(', ')[0], da.key) + # # lots of string manipulation to parse out the relevant pieces + # # of the python commands. + # self.assertEqual(arg_keys2[1: -2].split(', ')[0], da.key) - _key = arg_keys2[1: -2].split(', ')[1] - self.assertEqual(context._pull0(_key), 'question') - self.assertTrue("'answer'" in kw_keys2) + # _key = arg_keys2[1: -2].split(', ')[1] + # self.assertEqual(context._pull0(_key), 'question') + # self.assertTrue("'answer'" in kw_keys2) - self.assertTrue("'foo'" in kw_keys2) - self.assertTrue(db.key in kw_keys2) + # self.assertTrue("'foo'" in kw_keys2) + # self.assertTrue(db.key in kw_keys2) class TestLocalDecorator(TestCase): - # Functions for @local decorator tests. These are here so we can - # guarantee they are pushed to the engines before we try to use them. - @local - def local_add50(da): - return da + 50 - - @local - def local_add_num(da, num): - return da + num - - @local - def assert_allclose(da, db): - assert numpy.allclose(da, db), "Arrays not equal within tolerance." - - @local - def local_sin(da): - return numpy.sin(da) - @local - def local_sum(da): - return numpy.sum(da.get_localarray()) - - @local - def call_barrier(da): - from mpi4py import MPI - MPI.COMM_WORLD.Barrier() - return da - - @local - def local_add_nums(da, num1, num2, num3): - return da + num1 + num2 + num3 + @classmethod + def setUpClass(cls): + cls.context = ctx = Context() + cls.da = cls.context.empty((5, 5)) + cls.da.fill(2 * numpy.pi) + # Functions for @local decorator tests. These are here so we can + # guarantee they are pushed to the engines before we try to use them. + def local_add50(da): + return da + 50 + ctx.register(local_add50) + + def local_add_num(da, num): + return da + num + ctx.register(local_add_num) + + def assert_allclose(da, db): + assert numpy.allclose(da, db), "Arrays not equal within tolerance." + ctx.register(assert_allclose) + + def local_sin(da): + return numpy.sin(da) + ctx.register(local_sin) + + def local_sum(da): + return numpy.sum(da.get_localarray()) + ctx.register(local_sum) + + def call_barrier(da): + from mpi4py import MPI + MPI.COMM_WORLD.Barrier() + return da + ctx.register(call_barrier) - @local - def local_add_distarrayproxies(da, dg): - return da + dg + def local_add_nums(da, num1, num2, num3): + return da + num1 + num2 + num3 + ctx.register(local_add_nums) - @local - def local_add_mixed(da, num1, dg, num2): - return da + num1 + dg + num2 + def local_add_distarrayproxies(da, dg): + return da + dg + ctx.register(local_add_distarrayproxies) - @local - def local_add_ndarray(da, num, ndarr): - return da + num + ndarr + def local_add_mixed(da, num1, dg, num2): + return da + num1 + dg + num2 + ctx.register(local_add_mixed) - @local - def local_add_kwargs(da, num1, num2=55): - return da + num1 + num2 + def local_add_ndarray(da, num, ndarr): + return da + num + ndarr + ctx.register(local_add_ndarray) - @local - def local_add_supermix(da, num1, db, num2, dc, num3=99, num4=66): - return da + num1 + db + num2 + dc + num3 + num4 + def local_add_kwargs(da, num1, num2=55): + return da + num1 + num2 + ctx.register(local_add_kwargs) - @local - def local_none(da): - return None + def local_add_supermix(da, num1, db, num2, dc, num3=99, num4=66): + return da + num1 + db + num2 + dc + num3 + num4 + ctx.register(local_add_supermix) - @local - def parameterless(): - """This is a parameterless function.""" - return None + def local_none(da): + return None + ctx.register(local_none) - @classmethod - def setUpClass(cls): - cls.context = Context() - cls.da = cls.context.empty((5, 5)) - cls.da.fill(2 * numpy.pi) + def parameterless(): + """This is a parameterless function.""" + return None + ctx.register(parameterless) @classmethod def tearDownClass(cls): @@ -160,39 +160,39 @@ def fill_a(a): a[i, j] = i + j return a - @local def fill_da(da): for i in da.maps[0].global_iter: for j in da.maps[1].global_iter: da.global_index[i, j] = i + j return da + context.register(fill_da) - da = fill_da(da) + da = context.fill_da(da) a = fill_a(a) assert_array_equal(da.toarray(), a) - def test_different_contexts(self): - ctx1 = Context(targets=range(4)) - ctx2 = Context(targets=range(3)) - da1 = ctx1.ones((10,)) - da2 = ctx2.ones((10,)) - db1 = self.local_sin(da1) - db2 = self.local_sin(da2) - ndarr1 = db1.toarray() - ndarr2 = db2.toarray() - assert_array_equal(ndarr1, ndarr2) + # def test_different_contexts(self): + # ctx1 = Context(targets=range(4)) + # ctx2 = Context(targets=range(3)) + # da1 = ctx1.ones((10,)) + # da2 = ctx2.ones((10,)) + # db1 = self.local_sin(da1) + # db2 = self.local_sin(da2) + # ndarr1 = db1.toarray() + # ndarr2 = db2.toarray() + # assert_array_equal(ndarr1, ndarr2) def test_local_sin(self): - db = self.local_sin(self.da) - self.assert_allclose(db, 0) + db = self.context.local_sin(self.da) + self.context.assert_allclose(db, 0) def test_local_add50(self): - dc = self.local_add50(self.da) - self.assert_allclose(dc, 2 * numpy.pi + 50) + dc = self.context.local_add50(self.da) + self.context.assert_allclose(dc, 2 * numpy.pi + 50) def test_local_sum(self): - dd = self.local_sum(self.da) + dd = self.context.local_sum(self.da) lshapes = self.da.get_localshapes() expected = [] for lshape in lshapes: @@ -201,84 +201,85 @@ def test_local_sum(self): self.assertAlmostEqual(v, e, places=5) def test_local_add_num(self): - de = self.local_add_num(self.da, 11) - self.assert_allclose(de, 2 * numpy.pi + 11) + de = self.context.local_add_num(self.da, 11) + self.context.assert_allclose(de, 2 * numpy.pi + 11) def test_local_add_nums(self): - df = self.local_add_nums(self.da, 11, 12, 13) - self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) + df = self.context.local_add_nums(self.da, 11, 12, 13) + self.context.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): dg = self.context.empty((5, 5)) dg.fill(33) - dh = self.local_add_distarrayproxies(self.da, dg) - self.assert_allclose(dh, 33 + 2 * numpy.pi) + dh = self.context.local_add_distarrayproxies(self.da, dg) + self.context.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): di = self.context.empty((5, 5)) di.fill(33) - dj = self.local_add_mixed(self.da, 11, di, 12) - self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) + dj = self.context.local_add_mixed(self.da, 11, di, 12) + self.context.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @unittest.skip('Locally adding ndarrays not supported.') def test_local_add_ndarray(self): shp = self.da.get_localshapes()[0] ndarr = numpy.empty(shp) ndarr.fill(33) - dk = self.local_add_ndarray(self.da, 11, ndarr) - self.assert_allclose(dk, 2 * numpy.pi + 11 + 33) + dk = self.context.local_add_ndarray(self.da, 11, ndarr) + self.context.assert_allclose(dk, 2 * numpy.pi + 11 + 33) def test_local_add_kwargs(self): - dl = self.local_add_kwargs(self.da, 11, num2=12) - self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) + dl = self.context.local_add_kwargs(self.da, 11, num2=12) + self.context.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): dm = self.context.empty((5, 5)) dm.fill(22) dn = self.context.empty((5, 5)) dn.fill(44) - do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) + do = self.context.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66 - self.assert_allclose(do, expected) + self.context.assert_allclose(do, expected) def test_local_none(self): - dp = self.local_none(self.da) + dp = self.context.local_none(self.da) self.assertTrue(dp is None) def test_barrier(self): - self.call_barrier(self.da) + self.context.call_barrier(self.da) def test_parameterless(self): - self.assertRaises(TypeError, self.parameterless) + result = self.context.parameterless() + self.assertIsNone(result) def test_function_metadata(self): name = "parameterless" docstring = """This is a parameterless function.""" - self.assertEqual(self.parameterless.__name__, name) - self.assertEqual(self.parameterless.__doc__, docstring) + self.assertEqual(self.context.parameterless.__name__, name) + self.assertEqual(self.context.parameterless.__doc__, docstring) -class TestVectorizeDecorator(TestCase): +# class TestVectorizeDecorator(TestCase): - def test_vectorize(self): - """Test the @vectorize decorator for parity with NumPy's""" + # def test_vectorize(self): + # """Test the @vectorize decorator for parity with NumPy's""" - context = Context() + # context = Context() - a = numpy.arange(16).reshape(4, 4) - da = context.fromndarray(a) + # a = numpy.arange(16).reshape(4, 4) + # da = context.fromndarray(a) - @vectorize - def da_fn(a, b, c): - return a**2 + b + c + # @vectorize + # def da_fn(a, b, c): + # return a**2 + b + c - @numpy.vectorize - def a_fn(a, b, c): - return a**2 + b + c + # @numpy.vectorize + # def a_fn(a, b, c): + # return a**2 + b + c - a = a_fn(a, a, 6) - db = da_fn(da, da, 6) - assert_array_equal(db.toarray(), a) + # a = a_fn(a, a, 6) + # db = da_fn(da, da, 6) + # assert_array_equal(db.toarray(), a) if __name__ == '__main__': From 914011a0002356f51a60b00915ff307eb5d71da9 Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Wed, 23 Apr 2014 12:45:25 -0500 Subject: [PATCH 3/7] Remove unused imports. --- distarray/decorators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distarray/decorators.py b/distarray/decorators.py index b609681a..b0f72166 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -11,9 +11,9 @@ import functools from distarray.client import DistArray -from distarray.context import Context, DISTARRAY_BASE_NAME +from distarray.context import DISTARRAY_BASE_NAME from distarray.error import ContextError -from distarray.utils import has_exactly_one + def _rpc(func, args, kwargs, result_key, prefix): From 7710d0c167bf664e68669b864c62c1626b5b778d Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Tue, 29 Apr 2014 12:00:50 -0500 Subject: [PATCH 4/7] Refactoring of decorators. --- distarray/context.py | 9 +- distarray/decorators.py | 152 ++++++++++++++-------------- distarray/tests/test_decorators.py | 154 +++++++++++++---------------- 3 files changed, 148 insertions(+), 167 deletions(-) diff --git a/distarray/context.py b/distarray/context.py index 213fc85c..abbd70bd 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -111,13 +111,18 @@ def get_size(): block=True ) - # Function registration. + # `localize` and `vectorize` allow extra functions to be added to the context. - def register(self, func): + def localize(self, func): from distarray.decorators import local lf = local(func, self) setattr(self, func.__name__, lf) + def vectorize(self, func): + from distarray.decorators import vectorize + lf = vectorize(func, self) + setattr(self, func.__name__, lf) + # Key management routines: def _setup_key_context(self): diff --git a/distarray/decorators.py b/distarray/decorators.py index b0f72166..1aa1be7d 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -15,36 +15,14 @@ from distarray.error import ContextError -def _rpc(func, args, kwargs, result_key, prefix): - - main = __import__('__main__') - - from distarray.local.localarray import LocalArray - - args = list(args) - for idx, a in enumerate(args): - if isinstance(a, basestring): - if a.startswith(prefix): - args[idx] = getattr(main, a) - - for k, v in kwargs.items(): - if isinstance(v, basestring): - if v.startswith(prefix): - kwargs[k] = getattr(main, v) - - print args, kwargs - - res = func(*args, **kwargs) - if isinstance(res, LocalArray): - setattr(main, result_key, res) - return result_key - return res +class FunctionRegistrationBase(object): + """ + Base class for local function registration. + Subclasses: + Localize + Vectorize -class DecoratorBase(object): - """ - Base class for decorators, handles name wrapping and allows the - decorator to take an optional kwarg. """ def __init__(self, func, context): @@ -52,13 +30,8 @@ def __init__(self, func, context): self.func_key = self.func.__name__ functools.update_wrapper(self, func) self.context = context - self.push_func() - - def push_func(self): - """Push function to the engines.""" - self.context._push({self.func_key: self.func}) - def check_contexts(self, args, kwargs): + def determine_context(self, args, kwargs): """ Determine a context from a functions arguments.""" # inspect args for a context @@ -72,19 +45,9 @@ def check_contexts(self, args, kwargs): def build_args(self, args, kwargs): """ - Push a tuple of args and dict of kwargs to the engines. Return a - tuple with keys corresponding to args values on the engines. And a - dictionary with the same keys and values which are the keys to the - input dictionary's values. - - This allows us to use the following interface to execute code on - the engines: - - >>> def foo(*args, **kwargs): - >>> args, kwargs = _key_and_push_args(args, kwargs) - >>> exec_str = "remote_foo(*%s, **%s)" - >>> exec_str %= (args, kwargs) - >>> context.execute(exec_str) + Returns a new args tuple and kwargs dictionary with all distarrays in + the original args and kwargs arguments replaced by their .keys. + """ args = list(args) @@ -115,10 +78,6 @@ def process_return_value(self, result_from_target): client and return it. If all but one of the pulled values is None, return that non-None value only. """ - # type_key = self.context._generate_key() - # type_statement = "{} = str(type({}))".format(type_key, result_key) - # context._execute(type_statement) - # result_type_str = context._pull(type_key) results = list(result_from_target.values()) @@ -136,19 +95,66 @@ def process_return_value(self, result_from_target): return result -class local(DecoratorBase): +def _rpc_localize(func, args, kwargs, result_key, prefix): + + ns = __import__('__main__') + + from distarray.local.localarray import LocalArray + + args = list(args) + for idx, a in enumerate(args): + if isinstance(a, basestring): + if a.startswith(prefix): + args[idx] = getattr(ns, a) + + for k, v in kwargs.items(): + if isinstance(v, basestring): + if v.startswith(prefix): + kwargs[k] = getattr(ns, v) + + res = func(*args, **kwargs) + if isinstance(res, LocalArray): + setattr(ns, result_key, res) + return result_key + return res + + +class Localize(FunctionRegistrationBase): """Decorator to run a function locally on the engines.""" def __call__(self, *args, **kwargs): - # get context from args - context = self.check_contexts(args, kwargs) + context = self.determine_context(args, kwargs) args, kwargs = self.build_args(args, kwargs) result_key = context._generate_key() - results = context.view.apply_async(_rpc, self.func, args, kwargs, result_key, DISTARRAY_BASE_NAME).get_dict() + results = context.view.apply_async(_rpc_localize, self.func, + args, kwargs, result_key, + DISTARRAY_BASE_NAME).get_dict() return self.process_return_value(results) -class _vectorize(DecoratorBase): +def _rpc_vectorize(func, args, kwargs, out, prefix): + + ns = __import__('__main__') + import numpy as np + + args = list(args) + for idx, a in enumerate(args): + if isinstance(a, basestring): + if a.startswith(prefix): + args[idx] = getattr(ns, a).local_array + + for k, v in kwargs.items(): + if isinstance(v, basestring): + if v.startswith(prefix): + kwargs[k] = getattr(ns, v).local_array + + out = getattr(ns, out) + + func = np.vectorize(func) + out.local_array = func(*args) + + +class Vectorize(FunctionRegistrationBase): """ Analogous to numpy.vectorize. Input DistArray's must all be the same shape, and this will be the shape of the output distarray. @@ -158,30 +164,16 @@ def get_local_array(self, da, arg_keys): return arg_keys + [da.key + '.local_array'] def __call__(self, *args, **kwargs): - # get context from args - context = self.check_contexts(args, kwargs) - # push function - self.push_func(context, self.func_key, self.func) - # vectorize the function - exec_str = "%s = numpy.vectorize(%s)" % (self.func_key, self.func_key) - context._execute(exec_str) - - # Find the first distarray, they should all be the same up to the data. + context = self.determine_context(args, kwargs) + # TODO: FIXME: This uses an extra round-trip (or two (or three)) to + # create the `out` array. Better would be to create a new LocalArray + # inside _rpc_vectorize and return its metadata to create a DistArray + # using `.from_localarrays()`. for arg in args: if isinstance(arg, DistArray): - # Create the output distarray. out = context.empty(arg.shape, dtype=arg.dtype, - dist=arg.dist, - grid_shape=arg.grid_shape) - # parse args - args_str, kwargs_str = self.key_and_push_args( - args, kwargs, context=context, - da_handler=self.get_local_array) - - # Call the function - exec_str = ("if %s.local_array.size != 0: %s.local_array = " - "%s(*%s, **%s)") - exec_str %= (out.key, out.key, self.func_key, args_str, - kwargs_str) - context._execute(exec_str) - return out + dist=arg.dist, grid_shape=arg.grid_shape) + args, kwargs = self.build_args(args, kwargs) + context.view.apply_sync(_rpc_vectorize, self.func, + args, kwargs, out.key, DISTARRAY_BASE_NAME) + return out diff --git a/distarray/tests/test_decorators.py b/distarray/tests/test_decorators.py index c9e9b848..c4d64806 100644 --- a/distarray/tests/test_decorators.py +++ b/distarray/tests/test_decorators.py @@ -18,131 +18,126 @@ from numpy.testing import assert_array_equal from distarray.context import Context -from distarray.decorators import DecoratorBase, local +from distarray.decorators import FunctionRegistrationBase from distarray.error import ContextError -# class TestDecoratorBase(TestCase): +class TestFunctionRegistrationBase(TestCase): - # def test_determine_context(self): - # context = Context() - # context2 = Context() # for cross Context checking - # da = context.ones((2, 2)) - - # def dummy_func(*args, **kwargs): - # fn = lambda x: x - # db = DecoratorBase(fn) - # return db.determine_context(args, kwargs) + def test_determine_context(self): + context = Context() + context2 = Context() # for cross Context checking + da = context.ones((2, 2)) - # self.assertEqual(dummy_func(6, 7, context), context) - # self.assertEqual(dummy_func('ab', da), context) - # self.assertEqual(dummy_func(a=da), context) - # self.assertEqual(dummy_func(context, a=da), context) + def dummy_func(ctx, *args, **kwargs): + def fn(x): + return x + db = FunctionRegistrationBase(fn, ctx) + return db.determine_context(args, kwargs) - # self.assertRaises(TypeError, dummy_func, 'foo') - # self.assertRaises(ContextError, dummy_func, context, context2) + self.assertEqual(dummy_func(context, 6, 7), context) + self.assertEqual(dummy_func(context, 'ab', da), context) + self.assertEqual(dummy_func(context, a=da), context) + self.assertEqual(dummy_func(context, a=da), context) - # def test_key_and_push_args(self): - # context = Context() + db = context2.ones((2, 2)) + self.assertRaises(ContextError, dummy_func, context, db) - # da = context.ones((2, 2)) - # db = da*2 + def test_key_and_push_args(self): + context = Context() - # def dummy_func(*args, **kwargs): - # fn = lambda x: x - # db = DecoratorBase(fn) - # return db.key_and_push_args(args, kwargs) + da = context.ones((2, 2)) + db = da*2 - # # Push some distarrays - # arg_keys1, kw_keys1 = dummy_func(da, db, foo=da, bar=db) - # # with some other data too - # arg_keys2, kw_keys2 = dummy_func(da, 'question', answer=42, foo=db) + def dummy_func(ctx, *args, **kwargs): + def fn(x): + return x + db = FunctionRegistrationBase(fn, ctx) + return db.build_args(args, kwargs) - # self.assertEqual(arg_keys1, "(%s, %s,)" % (da.key, db.key)) - # # assert we pushed the right key, keystr pair - # self.assertTrue("'foo': %s" % (da.key) in kw_keys1) - # self.assertTrue("'bar': %s" % (db.key) in kw_keys1) + # Push some distarrays + arg_keys1, kw_keys1 = dummy_func(context, da, db, foo=da, bar=db) + # with some other data too + arg_keys2, kw_keys2 = dummy_func(context, da, 'question', answer=42, foo=db) - # # lots of string manipulation to parse out the relevant pieces - # # of the python commands. - # self.assertEqual(arg_keys2[1: -2].split(', ')[0], da.key) + self.assertSequenceEqual(arg_keys1, (da.key, db.key)) + # assert we pushed the right key, keystr pair + self.assertDictEqual({'foo': da.key, 'bar': db.key}, kw_keys1) - # _key = arg_keys2[1: -2].split(', ')[1] - # self.assertEqual(context._pull0(_key), 'question') - # self.assertTrue("'answer'" in kw_keys2) + # lots of string manipulation to parse out the relevant pieces + # of the python commands. + self.assertSequenceEqual(arg_keys2, (da.key, 'question')) - # self.assertTrue("'foo'" in kw_keys2) - # self.assertTrue(db.key in kw_keys2) + self.assertSetEqual(set("answer foo".split()), set(kw_keys2.keys())) + self.assertIn(db.key, kw_keys2.values()) class TestLocalDecorator(TestCase): - @classmethod def setUpClass(cls): cls.context = ctx = Context() cls.da = cls.context.empty((5, 5)) cls.da.fill(2 * numpy.pi) - # Functions for @local decorator tests. These are here so we can - # guarantee they are pushed to the engines before we try to use them. + def local_add50(da): return da + 50 - ctx.register(local_add50) + ctx.localize(local_add50) def local_add_num(da, num): return da + num - ctx.register(local_add_num) + ctx.localize(local_add_num) def assert_allclose(da, db): assert numpy.allclose(da, db), "Arrays not equal within tolerance." - ctx.register(assert_allclose) + ctx.localize(assert_allclose) def local_sin(da): return numpy.sin(da) - ctx.register(local_sin) + ctx.localize(local_sin) def local_sum(da): return numpy.sum(da.get_localarray()) - ctx.register(local_sum) + ctx.localize(local_sum) def call_barrier(da): from mpi4py import MPI MPI.COMM_WORLD.Barrier() return da - ctx.register(call_barrier) + ctx.localize(call_barrier) def local_add_nums(da, num1, num2, num3): return da + num1 + num2 + num3 - ctx.register(local_add_nums) + ctx.localize(local_add_nums) def local_add_distarrayproxies(da, dg): return da + dg - ctx.register(local_add_distarrayproxies) + ctx.localize(local_add_distarrayproxies) def local_add_mixed(da, num1, dg, num2): return da + num1 + dg + num2 - ctx.register(local_add_mixed) + ctx.localize(local_add_mixed) def local_add_ndarray(da, num, ndarr): return da + num + ndarr - ctx.register(local_add_ndarray) + ctx.localize(local_add_ndarray) def local_add_kwargs(da, num1, num2=55): return da + num1 + num2 - ctx.register(local_add_kwargs) + ctx.localize(local_add_kwargs) def local_add_supermix(da, num1, db, num2, dc, num3=99, num4=66): return da + num1 + db + num2 + dc + num3 + num4 - ctx.register(local_add_supermix) + ctx.localize(local_add_supermix) def local_none(da): return None - ctx.register(local_none) + ctx.localize(local_none) def parameterless(): """This is a parameterless function.""" return None - ctx.register(parameterless) + ctx.localize(parameterless) @classmethod def tearDownClass(cls): @@ -165,24 +160,13 @@ def fill_da(da): for j in da.maps[1].global_iter: da.global_index[i, j] = i + j return da - context.register(fill_da) + context.localize(fill_da) da = context.fill_da(da) a = fill_a(a) assert_array_equal(da.toarray(), a) - # def test_different_contexts(self): - # ctx1 = Context(targets=range(4)) - # ctx2 = Context(targets=range(3)) - # da1 = ctx1.ones((10,)) - # da2 = ctx2.ones((10,)) - # db1 = self.local_sin(da1) - # db2 = self.local_sin(da2) - # ndarr1 = db1.toarray() - # ndarr2 = db2.toarray() - # assert_array_equal(ndarr1, ndarr2) - def test_local_sin(self): db = self.context.local_sin(self.da) self.context.assert_allclose(db, 0) @@ -259,27 +243,27 @@ def test_function_metadata(self): self.assertEqual(self.context.parameterless.__doc__, docstring) -# class TestVectorizeDecorator(TestCase): +class TestVectorizeDecorator(TestCase): - # def test_vectorize(self): - # """Test the @vectorize decorator for parity with NumPy's""" + def test_vectorize(self): + """Test the @vectorize decorator for parity with NumPy's""" - # context = Context() + ctx = Context() - # a = numpy.arange(16).reshape(4, 4) - # da = context.fromndarray(a) + a = numpy.arange(16).reshape(4, 4) + da = ctx.fromndarray(a) - # @vectorize - # def da_fn(a, b, c): - # return a**2 + b + c + def da_fn(a, b, c): + return a**2 + b + c + ctx.vectorize(da_fn) - # @numpy.vectorize - # def a_fn(a, b, c): - # return a**2 + b + c + @numpy.vectorize + def a_fn(a, b, c): + return a**2 + b + c - # a = a_fn(a, a, 6) - # db = da_fn(da, da, 6) - # assert_array_equal(db.toarray(), a) + a = a_fn(a, a, 6) + db = ctx.da_fn(da, da, 6) + assert_array_equal(db.toarray(), a) if __name__ == '__main__': From 526fd720d37fa7b04a7c682b88aceff1c9da65a8 Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Tue, 29 Apr 2014 12:24:06 -0500 Subject: [PATCH 5/7] Cleanup docs, remove unused method. --- distarray/decorators.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/distarray/decorators.py b/distarray/decorators.py index 1aa1be7d..c48dbb49 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -120,7 +120,7 @@ def _rpc_localize(func, args, kwargs, result_key, prefix): class Localize(FunctionRegistrationBase): - """Decorator to run a function locally on the engines.""" + """Runs a function locally on the engines.""" def __call__(self, *args, **kwargs): context = self.determine_context(args, kwargs) @@ -156,13 +156,10 @@ def _rpc_vectorize(func, args, kwargs, out, prefix): class Vectorize(FunctionRegistrationBase): """ - Analogous to numpy.vectorize. Input DistArray's must all be the - same shape, and this will be the shape of the output distarray. + Like `Localize`, but vectorizes the function with numpy.vectorize and runs + it on the engines. """ - def get_local_array(self, da, arg_keys): - return arg_keys + [da.key + '.local_array'] - def __call__(self, *args, **kwargs): context = self.determine_context(args, kwargs) # TODO: FIXME: This uses an extra round-trip (or two (or three)) to From a663fdaaca3fe2a7a78aed0afbaa0189718b69bc Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Tue, 29 Apr 2014 12:26:50 -0500 Subject: [PATCH 6/7] Fix imports. --- distarray/context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/distarray/context.py b/distarray/context.py index abbd70bd..ba7e2cda 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -114,13 +114,13 @@ def get_size(): # `localize` and `vectorize` allow extra functions to be added to the context. def localize(self, func): - from distarray.decorators import local - lf = local(func, self) + from distarray.decorators import Localize + lf = Localize(func, self) setattr(self, func.__name__, lf) def vectorize(self, func): - from distarray.decorators import vectorize - lf = vectorize(func, self) + from distarray.decorators import Vectorize + lf = Vectorize(func, self) setattr(self, func.__name__, lf) # Key management routines: From 669a4f01069e47a760040a1fff91e093e4d4aa3d Mon Sep 17 00:00:00 2001 From: Kurt Smith Date: Wed, 30 Apr 2014 10:28:59 -0500 Subject: [PATCH 7/7] Remove references to `basestring`, replace with `string_types`. Makes Python3 happy. --- distarray/decorators.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/distarray/decorators.py b/distarray/decorators.py index c48dbb49..953a77bb 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -13,6 +13,7 @@ from distarray.client import DistArray from distarray.context import DISTARRAY_BASE_NAME from distarray.error import ContextError +from distarray.externals.six import string_types class FunctionRegistrationBase(object): @@ -81,7 +82,7 @@ def process_return_value(self, result_from_target): results = list(result_from_target.values()) - if all(isinstance(r, basestring) and r.startswith(DISTARRAY_BASE_NAME) + if all(isinstance(r, string_types) and r.startswith(DISTARRAY_BASE_NAME) for r in results): result = DistArray.from_localarrays(results[0], self.context) elif all(r is None for r in results): @@ -100,15 +101,16 @@ def _rpc_localize(func, args, kwargs, result_key, prefix): ns = __import__('__main__') from distarray.local.localarray import LocalArray + from distarray.externals.six import string_types args = list(args) for idx, a in enumerate(args): - if isinstance(a, basestring): + if isinstance(a, string_types): if a.startswith(prefix): args[idx] = getattr(ns, a) for k, v in kwargs.items(): - if isinstance(v, basestring): + if isinstance(v, string_types): if v.startswith(prefix): kwargs[k] = getattr(ns, v) @@ -136,15 +138,16 @@ def _rpc_vectorize(func, args, kwargs, out, prefix): ns = __import__('__main__') import numpy as np + from distarray.externals.six import string_types args = list(args) for idx, a in enumerate(args): - if isinstance(a, basestring): + if isinstance(a, string_types): if a.startswith(prefix): args[idx] = getattr(ns, a).local_array for k, v in kwargs.items(): - if isinstance(v, basestring): + if isinstance(v, string_types): if v.startswith(prefix): kwargs[k] = getattr(ns, v).local_array