diff --git a/distarray/dist/context.py b/distarray/dist/context.py index 807e57d8..c9120e0d 100644 --- a/distarray/dist/context.py +++ b/distarray/dist/context.py @@ -10,7 +10,6 @@ from __future__ import absolute_import -import uuid import collections import atexit @@ -22,7 +21,7 @@ from distarray.dist.maps import Distribution from distarray.dist.ipython_utils import IPythonClient -from distarray import DISTARRAY_BASE_NAME +from distarray.utils import uid, DISTARRAY_BASE_NAME class Context(object): @@ -69,8 +68,10 @@ def __init__(self, client=None, targets=None): # FIXME: IPython bug #4296: This doesn't work under Python 3 #with self.view.sync_imports(): # import distarray - self.view.execute("import distarray.local; " + self.view.execute("from functools import reduce; " + "import distarray.local; " "import distarray.local.mpiutils; " + "import distarray.utils; " "import numpy") self.context_key = self._setup_context_key() @@ -82,7 +83,7 @@ def _setup_context_key(self): Create a dict on the engines which will hold everything from this context. """ - context_key = DISTARRAY_BASE_NAME + self.uid() + context_key = uid() cmd = ("import types, sys;" "%s = types.ModuleType('%s');") cmd %= (context_key, context_key) @@ -139,21 +140,9 @@ def _set_engine_rank_mapping(self): # the intracomm. self.targets = [target_from_rank[i] for i in range(len(target_from_rank))] - # Key management routines: - @staticmethod - def _key_prefix(): - """ Get the base name for all keys. """ - return DISTARRAY_BASE_NAME - - @staticmethod - def uid(): - """Generate a unique valid python name.""" - # Full length seems excessively verbose so use 16 characters. - return Context._key_prefix() + uuid.uuid4().hex[:16] - def _generate_key(self): """ Generate a unique key name for this context. """ - key = "%s.%s" % (self.context_key, 'key_' + self.uid()) + key = "%s.%s" % (self.context_key, 'key_' + uid()) return key def _key_and_push(self, *values): @@ -503,3 +492,72 @@ def fromfunction(self, function, shape, **kwargs): '**{kwargs_name})') self._execute(cmd.format(**locals())) return DistArray.from_localarrays(da_name, distribution=distribution) + + def apply(self, func, args=None, kwargs=None, targets=None, + result_name=None): + """ + Analogous to IPython.parallel.view.apply_sync + + Parameters + ---------- + func : function + args : tuple + positional arguments to func + kwargs : dict + key word arguments to func + targets : sequence of integers + engines func is to be run on. + result_name : str + The name given the result on the engines. If given this is returned + to act as a proxy object. + + Returns + ------- + if result_name is not None : str + Name of the result on the engines. + else: list + A list of the results on all the engines. + """ + + def func_wrapper(func, result_name, args, kwargs): + """ + Function which calls the applied function after grabbing all the + arguments on the engines that are passed in as names of the form + `__distarray__`. + """ + main = __import__('__main__') + prefix = main.distarray.utils.DISTARRAY_BASE_NAME + + # convert args + args = list(args) + for i, a in enumerate(args): + if (isinstance(a, str) and a.startswith(prefix)): + args[i] = main.reduce(getattr, [main] + a.split('.')) + args = tuple(args) + + # convert kwargs + for k in kwargs.keys(): + val = kwargs[k] + if (isinstance(val, str) and val.startswith(prefix)): + kwargs[k] = main.reduce(getattr, [main] + val.split('.')) + + if result_name: + setattr(main, result_name, func(*args, **kwargs)) + return result_name + else: + return func(*args, **kwargs) + + # default arguments + args = () if args is None else args + kwargs = {} if kwargs is None else kwargs + wrapped_args = (func, result_name, args, kwargs) + + targets = self.targets if targets is None else targets + + result = self.view._really_apply(func_wrapper, args=wrapped_args, + targets=targets, block=True) + if result_name is not None: + # result is a list of the same name 4 times, so just return 1. + return result[0] + else: + return result diff --git a/distarray/dist/tests/test_context.py b/distarray/dist/tests/test_context.py index 0fb3f72e..2628ff80 100644 --- a/distarray/dist/tests/test_context.py +++ b/distarray/dist/tests/test_context.py @@ -126,5 +126,86 @@ def test_3D(self): self.assertEqual(c.grid_shape, (1, 1, 3)) +class TestApply(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.context = Context() + + def test_apply_no_args(self): + + def foo(): + return 42 + + val = self.context.apply(foo) + + self.assertEqual(val, [42]*4) + + def test_apply_pos_args(self): + + def foo(a, b, c): + return a + b + c + + # push all arguments + val = self.context.apply(foo, (1, 2, 3)) + self.assertEqual(val, [6]*4) + + # some local, some pushed + local_thing = self.context._key_and_push(2)[0] + val = self.context.apply(foo, (1, local_thing, 3)) + + self.assertEqual(val, [6]*4) + + # all pushed + local_args = self.context._key_and_push(1, 2, 3) + val = self.context.apply(foo, local_args) + + self.assertEqual(val, [6]*4) + + def test_apply_kwargs(self): + + def foo(a, b, c=None, d=None): + c = -1 if c is None else c + d = -2 if d is None else d + return a + b + c + d + + # empty kwargs + val = self.context.apply(foo, (1, 2)) + + self.assertEqual(val, [0]*4) + + # some empty + val = self.context.apply(foo, (1, 2), {'d': 3}) + + self.assertEqual(val, [5]*4) + + # all kwargs + val = self.context.apply(foo, (1, 2), {'c': 2, 'd': 3}) + + self.assertEqual(val, [8]*4) + + # now with local values + local_a = self.context._key_and_push(1)[0] + local_c = self.context._key_and_push(3)[0] + + val = self.context.apply(foo, (local_a, 2), {'c': local_c, 'd': 3}) + + self.assertEqual(val, [9]*4) + + def test_apply_return_val(self): + + def foo(a, b, c=None): + c = 3 if c is None else c + return a + b + c + + name = self.context.apply(foo, (1, 2), {'c': 5}, result_name='test') + + self.assertEqual(name, 'test') + + val = self.context._pull(name) + + self.assertEqual(val, [8]*len(self.context.targets)) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/distarray/utils.py b/distarray/utils.py index 52e52764..d8b4ab82 100644 --- a/distarray/utils.py +++ b/distarray/utils.py @@ -8,9 +8,14 @@ """ from math import sqrt +import uuid + from distarray.externals.six import next +DISTARRAY_BASE_NAME = '__distarray__' +def uid(): + return DISTARRAY_BASE_NAME + uuid.uuid4().hex[:16] def multi_for(iterables): if not iterables: