diff --git a/distarray/globalapi/context.py b/distarray/globalapi/context.py index c593fe5d..c73aad7c 100644 --- a/distarray/globalapi/context.py +++ b/distarray/globalapi/context.py @@ -71,7 +71,7 @@ def make_subcomm(self, new_targets): pass @abstractmethod - def apply(self, func, args=None, kwargs=None, targets=None): + def apply(self, func, args=None, kwargs=None, targets=None, autoproxyize=False): pass @abstractmethod @@ -816,6 +816,10 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize): # default arguments args = () if args is None else args kwargs = {} if kwargs is None else kwargs + + args = tuple(a.key if isinstance(a, DistArray) else a for a in args) + kwargs = {k: (v.key if isinstance(v, DistArray) else v) for k, v in kwargs.items()} + apply_nonce = nonce() wrapped_args = (func, apply_nonce, self.context_key, args, kwargs, autoproxyize) @@ -972,6 +976,10 @@ def apply(self, func, args=None, kwargs=None, targets=None, autoproxyize=False): # default arguments args = () if args is None else args kwargs = {} if kwargs is None else kwargs + + args = tuple(a.key if isinstance(a, DistArray) else a for a in args) + kwargs = {k: (v.key if isinstance(v, DistArray) else v) for k, v in kwargs.items()} + targets = self.targets if targets is None else targets apply_nonce = nonce() diff --git a/distarray/globalapi/tests/test_context.py b/distarray/globalapi/tests/test_context.py index bac00754..ea1eca1b 100644 --- a/distarray/globalapi/tests/test_context.py +++ b/distarray/globalapi/tests/test_context.py @@ -378,7 +378,6 @@ def foo(a, b, c=None, d=None): self.assertEqual(val, [9] * self.ntargets) - def test_apply_proxy(self): def foo(): @@ -401,6 +400,20 @@ def foo(): self.assertEqual(set(r[0].name for r in res), set([res[0][0].name])) self.assertEqual(set(r[-1].name for r in res), set([res[0][-1].name])) + def test_apply_distarray(self): + + da = self.context.empty((len(self.context.targets),), dtype=numpy.uint32) + + def local_label(la): + la.ndarray.fill(la.comm.rank) + + # Testing that we can pass in `da` and `apply()` extracts `da.key` automatically. + self.context.apply(local_label, (da,)) + assert_array_equal(da.tondarray(), range(len(self.context.targets))) + + self.context.apply(local_label, kwargs={'la': da}) + assert_array_equal(da.tondarray(), range(len(self.context.targets))) + class TestGetBaseComm(DefaultContextTestCase): ntargets = 'any'