diff --git a/distarray/context.py b/distarray/context.py index 41d10148..ba7e2cda 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -19,6 +19,8 @@ from distarray.client_map import ClientMDMap from distarray.ipython_utils import IPythonClient +DISTARRAY_BASE_NAME = '__distarray__' + class Context(object): ''' @@ -109,6 +111,18 @@ def get_size(): block=True ) + # `localize` and `vectorize` allow extra functions to be added to the context. + + def localize(self, func): + from distarray.decorators import Localize + lf = Localize(func, self) + setattr(self, func.__name__, lf) + + def vectorize(self, func): + from distarray.decorators import Vectorize + lf = Vectorize(func, self) + setattr(self, func.__name__, lf) + # Key management routines: def _setup_key_context(self): @@ -123,7 +137,7 @@ def _setup_key_context(self): def _key_basename(self): """ Get the base name for all keys. """ - return '_distarray_key' + return DISTARRAY_BASE_NAME def _key_prefix(self): """ Generate a prefix for a key name for this context. """ diff --git a/distarray/decorators.py b/distarray/decorators.py index c302b92e..953a77bb 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -11,108 +11,59 @@ import functools from distarray.client import DistArray -from distarray.context import Context +from distarray.context import DISTARRAY_BASE_NAME from distarray.error import ContextError -from distarray.utils import has_exactly_one +from distarray.externals.six import string_types -class DecoratorBase(object): - """ - Base class for decorators, handles name wrapping and allows the - decorator to take an optional kwarg. +class FunctionRegistrationBase(object): """ + Base class for local function registration. + + Subclasses: + Localize + Vectorize - def __init__(self, fn): - self.fn = fn - self.fn_key = self.fn.__name__ - functools.update_wrapper(self, fn) - self.context = None + """ - def push_fn(self, context, fn_key, fn): - """Push function to the engines.""" - context._push({fn_key: fn}) + def __init__(self, func, context): + self.func = func + self.func_key = self.func.__name__ + functools.update_wrapper(self, func) + self.context = context def determine_context(self, args, kwargs): """ Determine a context from a functions arguments.""" - contexts = [] # inspect args for a context for arg in args + tuple(kwargs.values()): if isinstance(arg, DistArray): - contexts.append(arg.context) - elif isinstance(arg, Context): - contexts.append(arg) - - # check the args had a context - if contexts == []: - raise TypeError('Function must take DistArray or Context objects.') + if arg.context != self.context: + msg = "DistArray %r not in same context as registered function %r." + raise ContextError(msg % (arg, self.func)) - # check that all contexts are equal - if not contexts.count(contexts[0]) == len(contexts): - msg = ("Arguments must use the same Context (given arguments of " - "type %r)") - msg %= (tuple(set(contexts)),) - raise ContextError(msg) + return self.context - return contexts[0] - - def key_and_push_args(self, args, kwargs, context=None, da_handler=None): - """ - Push a tuple of args and dict of kwargs to the engines. Return a - tuple with keys corresponding to args values on the engines. And a - dictionary with the same keys and values which are the keys to the - input dictionary's values. - - This allows us to use the following interface to execute code on - the engines: - - >>> def foo(*args, **kwargs): - >>> args, kwargs = _key_and_push_args(args, kwargs) - >>> exec_str = "remote_foo(*%s, **%s)" - >>> exec_str %= (args, kwargs) - >>> context.execute(exec_str) + def build_args(self, args, kwargs): """ + Returns a new args tuple and kwargs dictionary with all distarrays in + the original args and kwargs arguments replaced by their .keys. - if context is None: - context = self.determine_context(args, kwargs) + """ - # handle positional arguments - arg_keys = [] - push_keys = {} - for arg in args: + args = list(args) + for idx, arg in enumerate(args): if isinstance(arg, DistArray): - if da_handler is None: - arg_keys.append(arg.key) - # da_handler handles distarrays. - else: - arg_keys = da_handler(arg, arg_keys) - else: - new_key = context._generate_key() - arg_keys.append(new_key) - push_keys[new_key] = arg + args[idx] = arg.key # handle key word arguments - for kw in kwargs: - if isinstance(kwargs[kw], DistArray): - kwargs[kw] = kwargs[kw].key - else: - new_key = context._generate_key() - push_keys[new_key] = kwargs[kw] - kwargs[kw] = new_key - - # push the keys to the engines - context._push(push_keys) - - # build arg string - arg_str = '(' + ', '.join(arg_keys) + ',)' + for k, v in kwargs.items(): + if isinstance(v, DistArray): + kwargs[k] = v.key - # build kwarg string - kwarg_iter = ["'%s': %s" % (k, v) for (k, v) in kwargs.items()] - kwarg_str = '{' + ', '.join(kwarg_iter) + '}' + return args, kwargs - return arg_str, kwarg_str - - def process_return_value(self, context, result_key): + def process_return_value(self, result_from_target): """Figure out what to return on the Client. Parameters @@ -128,85 +79,101 @@ def process_return_value(self, context, result_key): client and return it. If all but one of the pulled values is None, return that non-None value only. """ - type_key = context._generate_key() - type_statement = "{} = str(type({}))".format(type_key, result_key) - context._execute(type_statement) - result_type_str = context._pull(type_key) - - def is_NoneType(typestring): - return (typestring == "" or - typestring == "") - - def is_LocalArray(typestring): - return (typestring == "") - - if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context) - elif all(is_NoneType(r) for r in result_type_str): + + results = list(result_from_target.values()) + + if all(isinstance(r, string_types) and r.startswith(DISTARRAY_BASE_NAME) + for r in results): + result = DistArray.from_localarrays(results[0], self.context) + elif all(r is None for r in results): result = None else: - result = context._pull(result_key) - if has_exactly_one(result): - result = next(x for x in result if x is not None) - + non_nones = [r for r in results if r is not None] + if len(non_nones) == 1: + result = non_nones[0] + else: + result = results return result -class local(DecoratorBase): - """Decorator to run a function locally on the engines.""" +def _rpc_localize(func, args, kwargs, result_key, prefix): + + ns = __import__('__main__') + + from distarray.local.localarray import LocalArray + from distarray.externals.six import string_types + + args = list(args) + for idx, a in enumerate(args): + if isinstance(a, string_types): + if a.startswith(prefix): + args[idx] = getattr(ns, a) + + for k, v in kwargs.items(): + if isinstance(v, string_types): + if v.startswith(prefix): + kwargs[k] = getattr(ns, v) + + res = func(*args, **kwargs) + if isinstance(res, LocalArray): + setattr(ns, result_key, res) + return result_key + return res + + +class Localize(FunctionRegistrationBase): + """Runs a function locally on the engines.""" def __call__(self, *args, **kwargs): - # get context from args context = self.determine_context(args, kwargs) - # push function - self.push_fn(context, self.fn_key, self.fn) - - args, kwargs = self.key_and_push_args(args, kwargs, - context=context) + args, kwargs = self.build_args(args, kwargs) result_key = context._generate_key() + results = context.view.apply_async(_rpc_localize, self.func, + args, kwargs, result_key, + DISTARRAY_BASE_NAME).get_dict() + return self.process_return_value(results) + + +def _rpc_vectorize(func, args, kwargs, out, prefix): + + ns = __import__('__main__') + import numpy as np + from distarray.externals.six import string_types - exec_str = "%s = %s(*%s, **%s)" - exec_str %= (result_key, self.fn_key, args, kwargs) - context._execute(exec_str) + args = list(args) + for idx, a in enumerate(args): + if isinstance(a, string_types): + if a.startswith(prefix): + args[idx] = getattr(ns, a).local_array - return self.process_return_value(context, result_key) + for k, v in kwargs.items(): + if isinstance(v, string_types): + if v.startswith(prefix): + kwargs[k] = getattr(ns, v).local_array + out = getattr(ns, out) -class vectorize(DecoratorBase): + func = np.vectorize(func) + out.local_array = func(*args) + + +class Vectorize(FunctionRegistrationBase): """ - Analogous to numpy.vectorize. Input DistArray's must all be the - same shape, and this will be the shape of the output distarray. + Like `Localize`, but vectorizes the function with numpy.vectorize and runs + it on the engines. """ - def get_local_array(self, da, arg_keys): - return arg_keys + [da.key + '.local_array'] - def __call__(self, *args, **kwargs): - # get context from args context = self.determine_context(args, kwargs) - # push function - self.push_fn(context, self.fn_key, self.fn) - # vectorize the function - exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key) - context._execute(exec_str) - - # Find the first distarray, they should all be the same up to the data. + # TODO: FIXME: This uses an extra round-trip (or two (or three)) to + # create the `out` array. Better would be to create a new LocalArray + # inside _rpc_vectorize and return its metadata to create a DistArray + # using `.from_localarrays()`. for arg in args: if isinstance(arg, DistArray): - # Create the output distarray. out = context.empty(arg.shape, dtype=arg.dtype, - dist=arg.dist, - grid_shape=arg.grid_shape) - # parse args - args_str, kwargs_str = self.key_and_push_args( - args, kwargs, context=context, - da_handler=self.get_local_array) - - # Call the function - exec_str = ("if %s.local_array.size != 0: %s.local_array = " - "%s(*%s, **%s)") - exec_str %= (out.key, out.key, self.fn_key, args_str, - kwargs_str) - context._execute(exec_str) - return out + dist=arg.dist, grid_shape=arg.grid_shape) + args, kwargs = self.build_args(args, kwargs) + context.view.apply_sync(_rpc_vectorize, self.func, + args, kwargs, out.key, DISTARRAY_BASE_NAME) + return out diff --git a/distarray/local/localarray.py b/distarray/local/localarray.py index 8b5d2563..657af035 100644 --- a/distarray/local/localarray.py +++ b/distarray/local/localarray.py @@ -29,6 +29,12 @@ from distarray.local.error import InvalidDimensionError, IncompatibleArrayError +def _rpc(payload): + funcname, args, kwargs = payload + func = globals()[funcname] + return func(*args, **kwargs) + + def _start_stop_block(size, proc_grid_size, proc_grid_rank): nelements = size // proc_grid_size if size % proc_grid_size != 0: diff --git a/distarray/plotting/plotting.py b/distarray/plotting/plotting.py index 1b89e399..368e7309 100644 --- a/distarray/plotting/plotting.py +++ b/distarray/plotting/plotting.py @@ -10,22 +10,7 @@ from six.moves import range from matplotlib import pyplot, colors, cm -from numpy import arange, concatenate, empty, linspace, resize - -from distarray.decorators import local - - -@local -def _get_ranks(arr): - """ - Given a distarray arr, return a distarray with the same shape, but - with the elements equal to the rank of the process the element is - on. - """ - out = arr.copy() - out.local_array[:] = arr.comm_rank - out.local_array = out.local_array.astype(int) - return out +from numpy import concatenate, empty, linspace def cmap_discretize(cmap, N): @@ -272,8 +257,22 @@ def plot_array_distribution(darray, # This is based somewhat on: # http://matplotlib.org/examples/api/colorbar_only.html + def _get_ranks(arr): + """ + Given a distarray arr, return a distarray with the same shape, but + with the elements equal to the rank of the process the element is + on. + """ + out = arr.copy() + out.local_array[:] = arr.comm_rank + out.local_array = out.local_array.astype(int) + return out + + ctx = darray.context + ctx.register(_get_ranks) + # Process per element. - process_darray = _get_ranks(darray) + process_darray = ctx._get_ranks(darray) process_array = process_darray.toarray() # Values per element. diff --git a/distarray/tests/test_decorators.py b/distarray/tests/test_decorators.py index b0110898..c4d64806 100644 --- a/distarray/tests/test_decorators.py +++ b/distarray/tests/test_decorators.py @@ -18,29 +18,30 @@ from numpy.testing import assert_array_equal from distarray.context import Context -from distarray.decorators import DecoratorBase, local, vectorize +from distarray.decorators import FunctionRegistrationBase from distarray.error import ContextError -class TestDecoratorBase(TestCase): +class TestFunctionRegistrationBase(TestCase): def test_determine_context(self): context = Context() context2 = Context() # for cross Context checking da = context.ones((2, 2)) - def dummy_func(*args, **kwargs): - fn = lambda x: x - db = DecoratorBase(fn) + def dummy_func(ctx, *args, **kwargs): + def fn(x): + return x + db = FunctionRegistrationBase(fn, ctx) return db.determine_context(args, kwargs) - self.assertEqual(dummy_func(6, 7, context), context) - self.assertEqual(dummy_func('ab', da), context) - self.assertEqual(dummy_func(a=da), context) + self.assertEqual(dummy_func(context, 6, 7), context) + self.assertEqual(dummy_func(context, 'ab', da), context) + self.assertEqual(dummy_func(context, a=da), context) self.assertEqual(dummy_func(context, a=da), context) - self.assertRaises(TypeError, dummy_func, 'foo') - self.assertRaises(ContextError, dummy_func, context, context2) + db = context2.ones((2, 2)) + self.assertRaises(ContextError, dummy_func, context, db) def test_key_and_push_args(self): context = Context() @@ -48,101 +49,95 @@ def test_key_and_push_args(self): da = context.ones((2, 2)) db = da*2 - def dummy_func(*args, **kwargs): - fn = lambda x: x - db = DecoratorBase(fn) - return db.key_and_push_args(args, kwargs) + def dummy_func(ctx, *args, **kwargs): + def fn(x): + return x + db = FunctionRegistrationBase(fn, ctx) + return db.build_args(args, kwargs) # Push some distarrays - arg_keys1, kw_keys1 = dummy_func(da, db, foo=da, bar=db) + arg_keys1, kw_keys1 = dummy_func(context, da, db, foo=da, bar=db) # with some other data too - arg_keys2, kw_keys2 = dummy_func(da, 'question', answer=42, foo=db) + arg_keys2, kw_keys2 = dummy_func(context, da, 'question', answer=42, foo=db) - self.assertEqual(arg_keys1, "(%s, %s,)" % (da.key, db.key)) + self.assertSequenceEqual(arg_keys1, (da.key, db.key)) # assert we pushed the right key, keystr pair - self.assertTrue("'foo': %s" % (da.key) in kw_keys1) - self.assertTrue("'bar': %s" % (db.key) in kw_keys1) + self.assertDictEqual({'foo': da.key, 'bar': db.key}, kw_keys1) # lots of string manipulation to parse out the relevant pieces # of the python commands. - self.assertEqual(arg_keys2[1: -2].split(', ')[0], da.key) - - _key = arg_keys2[1: -2].split(', ')[1] - self.assertEqual(context._pull0(_key), 'question') - self.assertTrue("'answer'" in kw_keys2) + self.assertSequenceEqual(arg_keys2, (da.key, 'question')) - self.assertTrue("'foo'" in kw_keys2) - self.assertTrue(db.key in kw_keys2) + self.assertSetEqual(set("answer foo".split()), set(kw_keys2.keys())) + self.assertIn(db.key, kw_keys2.values()) class TestLocalDecorator(TestCase): - # Functions for @local decorator tests. These are here so we can - # guarantee they are pushed to the engines before we try to use them. - @local - def local_add50(da): - return da + 50 + @classmethod + def setUpClass(cls): + cls.context = ctx = Context() + cls.da = cls.context.empty((5, 5)) + cls.da.fill(2 * numpy.pi) - @local - def local_add_num(da, num): - return da + num + def local_add50(da): + return da + 50 + ctx.localize(local_add50) - @local - def assert_allclose(da, db): - assert numpy.allclose(da, db), "Arrays not equal within tolerance." + def local_add_num(da, num): + return da + num + ctx.localize(local_add_num) - @local - def local_sin(da): - return numpy.sin(da) + def assert_allclose(da, db): + assert numpy.allclose(da, db), "Arrays not equal within tolerance." + ctx.localize(assert_allclose) - @local - def local_sum(da): - return numpy.sum(da.get_localarray()) + def local_sin(da): + return numpy.sin(da) + ctx.localize(local_sin) - @local - def call_barrier(da): - from mpi4py import MPI - MPI.COMM_WORLD.Barrier() - return da + def local_sum(da): + return numpy.sum(da.get_localarray()) + ctx.localize(local_sum) - @local - def local_add_nums(da, num1, num2, num3): - return da + num1 + num2 + num3 + def call_barrier(da): + from mpi4py import MPI + MPI.COMM_WORLD.Barrier() + return da + ctx.localize(call_barrier) - @local - def local_add_distarrayproxies(da, dg): - return da + dg + def local_add_nums(da, num1, num2, num3): + return da + num1 + num2 + num3 + ctx.localize(local_add_nums) - @local - def local_add_mixed(da, num1, dg, num2): - return da + num1 + dg + num2 + def local_add_distarrayproxies(da, dg): + return da + dg + ctx.localize(local_add_distarrayproxies) - @local - def local_add_ndarray(da, num, ndarr): - return da + num + ndarr + def local_add_mixed(da, num1, dg, num2): + return da + num1 + dg + num2 + ctx.localize(local_add_mixed) - @local - def local_add_kwargs(da, num1, num2=55): - return da + num1 + num2 + def local_add_ndarray(da, num, ndarr): + return da + num + ndarr + ctx.localize(local_add_ndarray) - @local - def local_add_supermix(da, num1, db, num2, dc, num3=99, num4=66): - return da + num1 + db + num2 + dc + num3 + num4 + def local_add_kwargs(da, num1, num2=55): + return da + num1 + num2 + ctx.localize(local_add_kwargs) - @local - def local_none(da): - return None + def local_add_supermix(da, num1, db, num2, dc, num3=99, num4=66): + return da + num1 + db + num2 + dc + num3 + num4 + ctx.localize(local_add_supermix) - @local - def parameterless(): - """This is a parameterless function.""" - return None + def local_none(da): + return None + ctx.localize(local_none) - @classmethod - def setUpClass(cls): - cls.context = Context() - cls.da = cls.context.empty((5, 5)) - cls.da.fill(2 * numpy.pi) + def parameterless(): + """This is a parameterless function.""" + return None + ctx.localize(parameterless) @classmethod def tearDownClass(cls): @@ -160,39 +155,28 @@ def fill_a(a): a[i, j] = i + j return a - @local def fill_da(da): for i in da.maps[0].global_iter: for j in da.maps[1].global_iter: da.global_index[i, j] = i + j return da + context.localize(fill_da) - da = fill_da(da) + da = context.fill_da(da) a = fill_a(a) assert_array_equal(da.toarray(), a) - def test_different_contexts(self): - ctx1 = Context(targets=range(4)) - ctx2 = Context(targets=range(3)) - da1 = ctx1.ones((10,)) - da2 = ctx2.ones((10,)) - db1 = self.local_sin(da1) - db2 = self.local_sin(da2) - ndarr1 = db1.toarray() - ndarr2 = db2.toarray() - assert_array_equal(ndarr1, ndarr2) - def test_local_sin(self): - db = self.local_sin(self.da) - self.assert_allclose(db, 0) + db = self.context.local_sin(self.da) + self.context.assert_allclose(db, 0) def test_local_add50(self): - dc = self.local_add50(self.da) - self.assert_allclose(dc, 2 * numpy.pi + 50) + dc = self.context.local_add50(self.da) + self.context.assert_allclose(dc, 2 * numpy.pi + 50) def test_local_sum(self): - dd = self.local_sum(self.da) + dd = self.context.local_sum(self.da) lshapes = self.da.get_localshapes() expected = [] for lshape in lshapes: @@ -201,61 +185,62 @@ def test_local_sum(self): self.assertAlmostEqual(v, e, places=5) def test_local_add_num(self): - de = self.local_add_num(self.da, 11) - self.assert_allclose(de, 2 * numpy.pi + 11) + de = self.context.local_add_num(self.da, 11) + self.context.assert_allclose(de, 2 * numpy.pi + 11) def test_local_add_nums(self): - df = self.local_add_nums(self.da, 11, 12, 13) - self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) + df = self.context.local_add_nums(self.da, 11, 12, 13) + self.context.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13) def test_local_add_distarrayproxies(self): dg = self.context.empty((5, 5)) dg.fill(33) - dh = self.local_add_distarrayproxies(self.da, dg) - self.assert_allclose(dh, 33 + 2 * numpy.pi) + dh = self.context.local_add_distarrayproxies(self.da, dg) + self.context.assert_allclose(dh, 33 + 2 * numpy.pi) def test_local_add_mixed(self): di = self.context.empty((5, 5)) di.fill(33) - dj = self.local_add_mixed(self.da, 11, di, 12) - self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) + dj = self.context.local_add_mixed(self.da, 11, di, 12) + self.context.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12) @unittest.skip('Locally adding ndarrays not supported.') def test_local_add_ndarray(self): shp = self.da.get_localshapes()[0] ndarr = numpy.empty(shp) ndarr.fill(33) - dk = self.local_add_ndarray(self.da, 11, ndarr) - self.assert_allclose(dk, 2 * numpy.pi + 11 + 33) + dk = self.context.local_add_ndarray(self.da, 11, ndarr) + self.context.assert_allclose(dk, 2 * numpy.pi + 11 + 33) def test_local_add_kwargs(self): - dl = self.local_add_kwargs(self.da, 11, num2=12) - self.assert_allclose(dl, 2 * numpy.pi + 11 + 12) + dl = self.context.local_add_kwargs(self.da, 11, num2=12) + self.context.assert_allclose(dl, 2 * numpy.pi + 11 + 12) def test_local_add_supermix(self): dm = self.context.empty((5, 5)) dm.fill(22) dn = self.context.empty((5, 5)) dn.fill(44) - do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) + do = self.context.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55) expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66 - self.assert_allclose(do, expected) + self.context.assert_allclose(do, expected) def test_local_none(self): - dp = self.local_none(self.da) + dp = self.context.local_none(self.da) self.assertTrue(dp is None) def test_barrier(self): - self.call_barrier(self.da) + self.context.call_barrier(self.da) def test_parameterless(self): - self.assertRaises(TypeError, self.parameterless) + result = self.context.parameterless() + self.assertIsNone(result) def test_function_metadata(self): name = "parameterless" docstring = """This is a parameterless function.""" - self.assertEqual(self.parameterless.__name__, name) - self.assertEqual(self.parameterless.__doc__, docstring) + self.assertEqual(self.context.parameterless.__name__, name) + self.assertEqual(self.context.parameterless.__doc__, docstring) class TestVectorizeDecorator(TestCase): @@ -263,21 +248,21 @@ class TestVectorizeDecorator(TestCase): def test_vectorize(self): """Test the @vectorize decorator for parity with NumPy's""" - context = Context() + ctx = Context() a = numpy.arange(16).reshape(4, 4) - da = context.fromndarray(a) + da = ctx.fromndarray(a) - @vectorize def da_fn(a, b, c): return a**2 + b + c + ctx.vectorize(da_fn) @numpy.vectorize def a_fn(a, b, c): return a**2 + b + c a = a_fn(a, a, 6) - db = da_fn(da, da, 6) + db = ctx.da_fn(da, da, 6) assert_array_equal(db.toarray(), a)