Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions distarray/dist/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
import functools

from distarray.dist.distarray import DistArray
from distarray.dist.context import Context
from distarray.error import ContextError
from distarray.dist.maps import Distribution
from distarray.error import DistributionError
from distarray.utils import has_exactly_one


Expand All @@ -34,29 +34,30 @@ def push_fn(self, context, fn_key, fn):
"""Push function to the engines."""
context._push({fn_key: fn}, targets=context.targets)

def determine_context(self, args, kwargs):
""" Determine a context from a functions arguments."""
def determine_distribution(self, args, kwargs):
""" Determine a distribution from a functions arguments."""

contexts = []
dists = []
# inspect args for a context
for arg in args + tuple(kwargs.values()):
if isinstance(arg, DistArray):
contexts.append(arg.context)
elif isinstance(arg, Context):
contexts.append(arg)
dists.append(arg.distribution)
elif isinstance(arg, Distribution):
dists.append(arg)

# check the args had a context
if contexts == []:
raise TypeError('Function must take DistArray or Context objects.')
if dists == []:
raise TypeError('Function must take DistArray or Distribution'
' objects.')

# check that all contexts are equal
if not contexts.count(contexts[0]) == len(contexts):
msg = ("Arguments must use the same Context (given arguments of "
"type %r)")
msg %= (tuple(set(contexts)),)
raise ContextError(msg)
if not dists.count(dists[0]) == len(dists):
msg = ("Arguments must use the same Distribution (given arguments "
"of type %r)")
msg %= (tuple(set(dists)),)
raise DistributionError(msg)

return contexts[0]
return dists[0]

def key_and_push_args(self, args, kwargs, context=None, da_handler=None):
"""
Expand All @@ -76,7 +77,8 @@ def key_and_push_args(self, args, kwargs, context=None, da_handler=None):
"""

if context is None:
context = self.determine_context(args, kwargs)
distribution = self.determine_distribution(args, kwargs)
context = distribution.context

# handle positional arguments
arg_keys = []
Expand Down Expand Up @@ -160,7 +162,8 @@ class local(DecoratorBase):

def __call__(self, *args, **kwargs):
# get context from args
context = self.determine_context(args, kwargs)
distribution = self.determine_distribution(args, kwargs)
context = distribution.context
# push function
self.push_fn(context, self.fn_key, self.fn)

Expand All @@ -170,7 +173,7 @@ def __call__(self, *args, **kwargs):

exec_str = "%s = %s(*%s, **%s)"
exec_str %= (result_key, self.fn_key, args, kwargs)
context._execute(exec_str, targets=context.targets)
context._execute(exec_str, targets=distribution.targets)

return self.process_return_value(context, result_key)

Expand All @@ -186,12 +189,13 @@ def get_ndarray(self, da, arg_keys):

def __call__(self, *args, **kwargs):
# get context from args
context = self.determine_context(args, kwargs)
distribution = self.determine_distribution(args, kwargs)
context = distribution.context
# push function
self.push_fn(context, self.fn_key, self.fn)
# vectorize the function
exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key)
context._execute(exec_str, targets=context.targets)
context._execute(exec_str, targets=distribution.targets)

# Find the first distarray, they should all be the same up to the data.
for arg in args:
Expand Down
32 changes: 15 additions & 17 deletions distarray/dist/tests/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,30 @@
from distarray.dist.context import Context
from distarray.dist.maps import Distribution
from distarray.dist.decorators import DecoratorBase, local, vectorize
from distarray.error import ContextError
from distarray.error import DistributionError


class TestDecoratorBase(TestCase):

def test_determine_context(self):
def test_determine_distribution(self):
context = Context()
context2 = Context() # for cross Context checking
distribution = Distribution.from_shape(context, (2, 2))
da = context.ones(distribution)
dist = Distribution.from_shape(context, (2, 2))
dist2 = Distribution.from_shape(context2, (2, 2))
da = context.ones(dist)

def dummy_func(*args, **kwargs):
fn = lambda x: x
db = DecoratorBase(fn)
return db.determine_context(args, kwargs)
return db.determine_distribution(args, kwargs)

self.assertEqual(dummy_func(6, 7, context), context)
self.assertEqual(dummy_func('ab', da), context)
self.assertEqual(dummy_func(a=da), context)
self.assertEqual(dummy_func(context, a=da), context)
self.assertEqual(dummy_func(6, 7, dist), dist)
self.assertEqual(dummy_func('ab', da), dist)
self.assertEqual(dummy_func(a=da), dist)
self.assertEqual(dummy_func(dist, a=da), dist)

self.assertRaises(TypeError, dummy_func, 'foo')
self.assertRaises(ContextError, dummy_func, context, context2)
self.assertRaises(DistributionError, dummy_func, dist, dist2)

def test_key_and_push_args(self):
context = Context()
Expand Down Expand Up @@ -219,15 +220,13 @@ def test_local_add_nums(self):
self.assert_allclose(df, 2 * numpy.pi + 11 + 12 + 13)

def test_local_add_distarrayproxies(self):
distribution = Distribution.from_shape(self.context, (5, 5))
dg = self.context.empty(distribution)
dg = self.context.empty(self.da.distribution)
dg.fill(33)
dh = self.local_add_distarrayproxies(self.da, dg)
self.assert_allclose(dh, 33 + 2 * numpy.pi)

def test_local_add_mixed(self):
distribution = Distribution.from_shape(self.context, (5, 5))
di = self.context.empty(distribution)
di = self.context.empty(self.da.distribution)
di.fill(33)
dj = self.local_add_mixed(self.da, 11, di, 12)
self.assert_allclose(dj, 2 * numpy.pi + 11 + 33 + 12)
Expand All @@ -245,10 +244,9 @@ def test_local_add_kwargs(self):
self.assert_allclose(dl, 2 * numpy.pi + 11 + 12)

def test_local_add_supermix(self):
distribution = Distribution.from_shape(self.context, (5, 5))
dm = self.context.empty(distribution)
dm = self.context.empty(self.da.distribution)
dm.fill(22)
dn = self.context.empty(distribution)
dn = self.context.empty(self.da.distribution)
dn.fill(44)
do = self.local_add_supermix(self.da, 11, dm, 33, dc=dn, num3=55)
expected = 2 * numpy.pi + 11 + 22 + 33 + 44 + 55 + 66
Expand Down
4 changes: 4 additions & 0 deletions distarray/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,7 @@ class InvalidRankError(MPIDistArrayError):

class MPICommError(MPIDistArrayError):
pass


class DistributionError(DistArrayError):
pass