Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions distarray/dist/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from __future__ import absolute_import

import uuid
import collections
import atexit

Expand All @@ -22,7 +21,7 @@
from distarray.dist.maps import Distribution

from distarray.dist.ipython_utils import IPythonClient
from distarray import DISTARRAY_BASE_NAME
from distarray.utils import uid, DISTARRAY_BASE_NAME


class Context(object):
Expand Down Expand Up @@ -69,8 +68,10 @@ def __init__(self, client=None, targets=None):
# FIXME: IPython bug #4296: This doesn't work under Python 3
#with self.view.sync_imports():
# import distarray
self.view.execute("import distarray.local; "
self.view.execute("from functools import reduce; "
"import distarray.local; "
"import distarray.local.mpiutils; "
"import distarray.utils; "
"import numpy")

self.context_key = self._setup_context_key()
Expand All @@ -82,7 +83,7 @@ def _setup_context_key(self):
Create a dict on the engines which will hold everything from
this context.
"""
context_key = DISTARRAY_BASE_NAME + self.uid()
context_key = uid()
cmd = ("import types, sys;"
"%s = types.ModuleType('%s');")
cmd %= (context_key, context_key)
Expand Down Expand Up @@ -139,21 +140,9 @@ def _set_engine_rank_mapping(self):
# the intracomm.
self.targets = [target_from_rank[i] for i in range(len(target_from_rank))]

# Key management routines:
@staticmethod
def _key_prefix():
""" Get the base name for all keys. """
return DISTARRAY_BASE_NAME

@staticmethod
def uid():
"""Generate a unique valid python name."""
# Full length seems excessively verbose so use 16 characters.
return Context._key_prefix() + uuid.uuid4().hex[:16]

def _generate_key(self):
""" Generate a unique key name for this context. """
key = "%s.%s" % (self.context_key, 'key_' + self.uid())
key = "%s.%s" % (self.context_key, 'key_' + uid())
return key

def _key_and_push(self, *values):
Expand Down Expand Up @@ -503,3 +492,72 @@ def fromfunction(self, function, shape, **kwargs):
'**{kwargs_name})')
self._execute(cmd.format(**locals()))
return DistArray.from_localarrays(da_name, distribution=distribution)

def apply(self, func, args=None, kwargs=None, targets=None,
result_name=None):
"""
Analogous to IPython.parallel.view.apply_sync

Parameters
----------
func : function
args : tuple
positional arguments to func
kwargs : dict
key word arguments to func
targets : sequence of integers
engines func is to be run on.
result_name : str
The name given the result on the engines. If given this is returned
to act as a proxy object.

Returns
-------
if result_name is not None : str
Name of the result on the engines.
else: list
A list of the results on all the engines.
"""

def func_wrapper(func, result_name, args, kwargs):
"""
Function which calls the applied function after grabbing all the
arguments on the engines that are passed in as names of the form
`__distarray__<some uuid>`.
"""
main = __import__('__main__')
prefix = main.distarray.utils.DISTARRAY_BASE_NAME

# convert args
args = list(args)
for i, a in enumerate(args):
if (isinstance(a, str) and a.startswith(prefix)):
args[i] = main.reduce(getattr, [main] + a.split('.'))
args = tuple(args)

# convert kwargs
for k in kwargs.keys():
val = kwargs[k]
if (isinstance(val, str) and val.startswith(prefix)):
kwargs[k] = main.reduce(getattr, [main] + val.split('.'))

if result_name:
setattr(main, result_name, func(*args, **kwargs))
return result_name
else:
return func(*args, **kwargs)

# default arguments
args = () if args is None else args
kwargs = {} if kwargs is None else kwargs
wrapped_args = (func, result_name, args, kwargs)

targets = self.targets if targets is None else targets

result = self.view._really_apply(func_wrapper, args=wrapped_args,
targets=targets, block=True)
if result_name is not None:
# result is a list of the same name 4 times, so just return 1.
return result[0]
else:
return result
81 changes: 81 additions & 0 deletions distarray/dist/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,86 @@ def test_3D(self):
self.assertEqual(c.grid_shape, (1, 1, 3))


class TestApply(unittest.TestCase):

@classmethod
def setUpClass(cls):
cls.context = Context()

def test_apply_no_args(self):

def foo():
return 42

val = self.context.apply(foo)

self.assertEqual(val, [42]*4)

def test_apply_pos_args(self):

def foo(a, b, c):
return a + b + c

# push all arguments
val = self.context.apply(foo, (1, 2, 3))
self.assertEqual(val, [6]*4)

# some local, some pushed
local_thing = self.context._key_and_push(2)[0]
val = self.context.apply(foo, (1, local_thing, 3))

self.assertEqual(val, [6]*4)

# all pushed
local_args = self.context._key_and_push(1, 2, 3)
val = self.context.apply(foo, local_args)

self.assertEqual(val, [6]*4)

def test_apply_kwargs(self):

def foo(a, b, c=None, d=None):
c = -1 if c is None else c
d = -2 if d is None else d
return a + b + c + d

# empty kwargs
val = self.context.apply(foo, (1, 2))

self.assertEqual(val, [0]*4)

# some empty
val = self.context.apply(foo, (1, 2), {'d': 3})

self.assertEqual(val, [5]*4)

# all kwargs
val = self.context.apply(foo, (1, 2), {'c': 2, 'd': 3})

self.assertEqual(val, [8]*4)

# now with local values
local_a = self.context._key_and_push(1)[0]
local_c = self.context._key_and_push(3)[0]

val = self.context.apply(foo, (local_a, 2), {'c': local_c, 'd': 3})

self.assertEqual(val, [9]*4)

def test_apply_return_val(self):

def foo(a, b, c=None):
c = 3 if c is None else c
return a + b + c

name = self.context.apply(foo, (1, 2), {'c': 5}, result_name='test')

self.assertEqual(name, 'test')

val = self.context._pull(name)

self.assertEqual(val, [8]*len(self.context.targets))


if __name__ == '__main__':
unittest.main(verbosity=2)
5 changes: 5 additions & 0 deletions distarray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@
"""

from math import sqrt
import uuid

from distarray.externals.six import next

DISTARRAY_BASE_NAME = '__distarray__'

def uid():
return DISTARRAY_BASE_NAME + uuid.uuid4().hex[:16]

def multi_for(iterables):
if not iterables:
Expand Down