diff --git a/distarray/context.py b/distarray/context.py index 498ec36a..74434670 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -477,11 +477,28 @@ def fromndarray(self, arr, dist=None, grid_shape=None): fromarray = fromndarray def fromfunction(self, function, shape, **kwargs): - func_key = self._generate_key() - self.view.push_function({func_key: function}, targets=self.targets, - block=True) - keys = self._key_and_push(shape, kwargs) - new_key = self._generate_key() - subs = (new_key, func_key) + keys - self._execute('%s = distarray.local.fromfunction(%s,%s,**%s)' % subs) - return DistArray.from_localarrays(new_key, context=self) + """Create a DistArray from a function over global indices. + + Unlike numpy's `fromfunction`, the result of distarray's + `fromfunction` is restricted to the same Distribution as the + index array generated from `shape`. + + See numpy.fromfunction for more details. + """ + dtype = kwargs.get('dtype', None) + dist = kwargs.get('dist', None) + grid_shape = kwargs.get('grid_shape', None) + distribution = Distribution.from_shape(context=self, + shape=shape, dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + function_name, ddpr_name, kwargs_name = \ + self._key_and_push(function, ddpr, kwargs) + da_name = self._generate_key() + comm_name = self._comm_key + cmd = ('{da_name} = distarray.local.fromfunction({function_name}, ' + 'distarray.local.maps.Distribution(' + '{ddpr_name}[{comm_name}.Get_rank()], comm={comm_name}),' + '**{kwargs_name})') + self._execute(cmd.format(**locals())) + return DistArray.from_localarrays(da_name, distribution=distribution) diff --git a/distarray/tests/test_client.py b/distarray/tests/test_client.py index 4de68e60..38b0ff39 100644 --- a/distarray/tests/test_client.py +++ b/distarray/tests/test_client.py @@ -354,6 +354,13 @@ def test_grid_rank(self): grid_shape=(1, 1, 4)) self.assertEqual(a.grid_shape, (1, 1, 4)) + def test_fromfunction(self): + fn = lambda i, j: i + j + shape = (7, 9) + expected = numpy.fromfunction(fn, shape, dtype=int) + result = self.context.fromfunction(fn, shape, dtype=int) + assert_array_equal(expected, result.tondarray()) + class TestReduceMethods(unittest.TestCase): """Test reduction methods"""