diff --git a/distarray/dist/functions.py b/distarray/dist/functions.py index 950c61ce..e0f40131 100644 --- a/distarray/dist/functions.py +++ b/distarray/dist/functions.py @@ -46,7 +46,8 @@ def func_call(func_name, arr_name, args, kwargs): res = func(arr_name, *args, **kwargs) return proxyize(res), res.dtype # noqa - res = context.apply(func_call, args=(name, a.key, args, kwargs)) + res = context.apply(func_call, args=(name, a.key, args, kwargs), + targets=a.targets) new_key = res[0][0] dtype = res[0][1] return DistArray.from_localarrays(new_key, @@ -84,7 +85,8 @@ def func_call(func_name, a, b, args, kwargs): res = func(a, b, *args, **kwargs) return proxyize(res), res.dtype # noqa - res = context.apply(func_call, args=(name, a_key, b_key, args, kwargs)) + res = context.apply(func_call, args=(name, a_key, b_key, args, kwargs), + targets=distribution.targets) new_key = res[0][0] dtype = res[0][1] return DistArray.from_localarrays(new_key, diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index bca5f2f6..bef6d3d4 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -411,7 +411,7 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None): self.ndim = len(dd0) self.dist = tuple(dd['dist_type'] for dd in dd0) self.grid_shape = tuple(dd['proc_grid_size'] for dd in dd0) - self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim, + self.grid_shape = normalize_grid_shape(self.grid_shape, self.shape, self.dist, len(self.targets)) coords = [tuple(d['proc_grid_rank'] for d in dd) for dd in @@ -453,8 +453,6 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None, self = cls.__new__(cls) self.context = context - self.targets = sorted(targets or context.targets) - self.comm = self.context._make_subcomm(self.targets) self.shape = shape self.ndim = len(shape) @@ -463,13 +461,19 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None, dist = {0: 'b'} self.dist = normalize_dist(dist, self.ndim) + # all possible 
targets + all_targets = sorted(targets or context.targets) # grid_shape if grid_shape is None: grid_shape = make_grid_shape(self.shape, self.dist, - len(self.targets)) + len(all_targets)) - self.grid_shape = normalize_grid_shape(grid_shape, self.ndim, - self.dist, len(self.targets)) + self.grid_shape = normalize_grid_shape(grid_shape, self.shape, + self.dist, len(all_targets)) + ntargets = reduce(operator.mul, self.grid_shape, 1) + # choose targets from grid_shape + self.targets = all_targets[:ntargets] + self.comm = self.context._make_subcomm(self.targets) # TODO: FIXME: assert that self.rank_from_coords is valid and conforms # to how MPI does it. @@ -579,7 +583,7 @@ def __init__(self, context, global_dim_data, targets=None): self.dist = tuple(m.dist for m in self.maps) self.grid_shape = tuple(m.grid_size for m in self.maps) - self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim, + self.grid_shape = normalize_grid_shape(self.grid_shape, self.shape, self.dist, len(self.targets)) nelts = reduce(operator.mul, self.grid_shape, 1) diff --git a/distarray/dist/random.py b/distarray/dist/random.py index b1fc7d4c..25133e8c 100644 --- a/distarray/dist/random.py +++ b/distarray/dist/random.py @@ -62,7 +62,7 @@ def rand(self, distribution): da_key = self.context._generate_key() ddpr = distribution.get_dim_data_per_rank() ddpr_name = self.context._key_and_push(ddpr)[0] - comm_name = self.context.comm + comm_name = distribution.comm self.context._execute( '{da_key} = distarray.local.random.rand(' 'distribution=distarray.local.maps.Distribution(' @@ -122,7 +122,7 @@ def normal(self, distribution, loc=0.0, scale=1.0): ddpr = distribution.get_dim_data_per_rank() loc_name, scale_name, ddpr_name = \ self.context._key_and_push(loc, scale, ddpr) - comm_name = self.context.comm + comm_name = distribution.comm self.context._execute( '{da_key} = distarray.local.random.normal(' 'loc={loc_name}, scale={scale_name},' @@ -160,7 +160,7 @@ def randint(self, distribution, low, 
high=None): ddpr = distribution.get_dim_data_per_rank() low_name, high_name, ddpr_name = \ self.context._key_and_push(low, high, ddpr) - comm_name = self.context.comm + comm_name = distribution.comm self.context._execute( '{da_key} = distarray.local.random.randint(' 'low={low_name}, high={high_name},' @@ -186,7 +186,7 @@ def randn(self, distribution): da_key = self.context._generate_key() ddpr = distribution.get_dim_data_per_rank() ddpr_name = self.context._key_and_push(ddpr)[0] - comm_name = self.context.comm + comm_name = distribution.comm self.context._execute( '{da_key} = distarray.local.random.randn(' 'distribution=distarray.local.maps.Distribution(' diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index e276b3b7..357f6d52 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -439,11 +439,16 @@ def test_mean_multiaxis(self): da_mean = self.darr.mean(axis=(0, 1)) assert_allclose(da_mean.tondarray(), np_mean) - def test_mean_along_axis_1(self): + def test_mean_along_axis_0(self): da_mean = self.darr.mean(axis=0) np_mean = self.arr.mean(axis=0) assert_allclose(da_mean.tondarray(), np_mean) + def test_mean_along_axis_1(self): + da_mean = self.darr.mean(axis=1) + np_mean = self.arr.mean(axis=1) + assert_allclose(da_mean.tondarray(), np_mean) + def test_mean_dtype(self): da_mean = self.darr.mean(axis=0, dtype=int) np_mean = self.arr.mean(axis=0, dtype=int) @@ -531,17 +536,6 @@ def test_sum_4D_cyclic(self): assert_allclose(darr_sum.tondarray(), arr_sum) assert_allclose(darr.sum().tondarray(), arr.sum()) - def test_empty_localarray(self): - if len(self.context.targets) < 2: - raise self.skipTest("not enough targets to run test.") - dist = Distribution.from_shape(self.context, - shape=(1,), - dist=('b',), - targets=self.context.targets[:2]) - darr = self.context.ones(dist) - self.assertRaises(NotImplementedError, darr.min, ()) - self.assertRaises(NotImplementedError, darr.sum, (), 
{'axis':0}) - class TestFromLocalArrays(ContextTestCase): diff --git a/distarray/local/maps.py b/distarray/local/maps.py index faa7ade7..dc654037 100644 --- a/distarray/local/maps.py +++ b/distarray/local/maps.py @@ -56,7 +56,7 @@ def from_shape(cls, comm, shape, dist=None, grid_shape=None): if grid_shape is None: # Make a new grid_shape if not provided. grid_shape = make_grid_shape(shape, dist_tuple, comm_size) - grid_shape = normalize_grid_shape(grid_shape, ndim, + grid_shape = normalize_grid_shape(grid_shape, shape, dist_tuple, comm_size) comm = construct.init_comm(base_comm, grid_shape) diff --git a/distarray/local/tests/paralleltest_maps.py b/distarray/local/tests/paralleltest_maps.py index 51b8bd70..48343812 100644 --- a/distarray/local/tests/paralleltest_maps.py +++ b/distarray/local/tests/paralleltest_maps.py @@ -89,7 +89,7 @@ def test_basic_2d(self): def test_bad_distribution(self): """Test that invalid distribution type fails as expected.""" - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): # Invalid distribution type 'x'. Distribution.from_shape(comm=self.comm, shape=(7,), dist={0: 'x'}, grid_shape=(4,)) diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index fda53a25..f3c3eef9 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -24,23 +24,72 @@ class GridShapeError(Exception): pass -def normalize_grid_shape(grid_shape, ndims, dist, comm_size): +def check_grid_shape_preconditions(shape, dist, comm_size): + """ + Verify various distarray parameters are correct before making a grid_shape. 
+ """ + if comm_size < 1: + raise ValueError("comm_size >= 1 not satisfied, comm_size = %s" % + (comm_size,)) + if len(shape) != len(dist): + raise ValueError("len(shape) == len(dist) not satisfied, len(shape) =" + " %s and len(dist) = %s" % (len(shape), len(dist))) + if any(i < 0 for i in shape): + raise ValueError("shape must be a sequence of non-negative integers, " + "shape = %s" % (shape,)) + if any(i not in ('b', 'c', 'n', 'u') for i in dist): + raise ValueError("dist must be a sequence of 'b', 'n', 'c', 'u' " + "strings, dist = %s" % (dist,)) + + +def check_grid_shape_postconditions(grid_shape, shape, dist, comm_size): + if not (len(grid_shape) == len(shape) == len(dist)): + raise ValueError("len(grid_shape) == len(shape) == len(dist) not " + "satisfied, len(grid_shape) = %s and len(shape) = %s " + "and len(dist) = %s" % (len(grid_shape), len(shape), + len(dist))) + if any(gs < 1 for gs in grid_shape): + raise ValueError("all(gs >= 1 for gs in grid_shape) not satisfied, " + "grid_shape = %s" % (grid_shape,)) + if any(gs != 1 for (d, gs) in zip(dist, grid_shape) if d == 'n'): + raise ValueError("all(gs == 1 for (d, gs) in zip(dist, grid_shape) if " + "d == 'n') not satisfied, dist = %s and grid_shape = " + "%s" % (dist, grid_shape)) + if any(gs > s for (s, gs) in zip(shape, grid_shape) if s > 0): + raise ValueError("all(gs <= s for (s, gs) in zip(shape, grid_shape) " + "if s > 0) not satisfied, shape = %s and grid_shape " + "= %s" % (shape, grid_shape)) + if reduce(operator.mul, grid_shape, 1) > comm_size: + raise ValueError("reduce(operator.mul, grid_shape, 1) <= comm_size not" + " satisfied, grid_shape = %s product = %s and " + "comm_size = %s" % ( + grid_shape, + reduce(operator.mul, grid_shape, 1), + comm_size)) + + +def normalize_grid_shape(grid_shape, shape, dist, comm_size): """Adds 1s to grid_shape so it has `ndims` dimensions. Validates `grid_shape` tuple against the `dist` tuple and `comm_size`. 
""" + def check_normalization_preconditions(grid_shape, dist): + if any(i < 0 for i in grid_shape): + raise ValueError("grid_shape must be a sequence of non-negative " + "integers, grid_shape = %s" % (grid_shape,)) + if len(grid_shape) > len(dist): + raise ValueError("len(grid_shape) <= len(dist) not satisfied, " + "len(grid_shape) = %s and len(dist) = %s" % + (len(grid_shape), len(dist))) + check_grid_shape_preconditions(shape, dist, comm_size) + check_normalization_preconditions(grid_shape, dist) + + ndims = len(shape) grid_shape = tuple(grid_shape) + (1,) * (ndims - len(grid_shape)) - # short circuit for special case - if all(x == 'n' for x in dist): - if not all(x == 1 for x in grid_shape): - raise ValueError("grid shape should be all `1`'s not %s." % - grid_shape) - return grid_shape - if len(grid_shape) != len(dist): msg = "grid_shape's length (%d) not equal to dist's length (%d)" raise InvalidGridShapeError(msg % (len(grid_shape), len(dist))) - if reduce(operator.mul, grid_shape, 1) != comm_size: + if reduce(operator.mul, grid_shape, 1) > comm_size: msg = "grid shape %r not compatible with comm size of %d." raise InvalidGridShapeError(msg % (grid_shape, comm_size)) return grid_shape @@ -73,16 +122,19 @@ def make_grid_shape(shape, dist, comm_size): if not possible to distribute `comm_size` processes over number of dimensions. """ - if not isinstance(dist, Sequence): - raise TypeError("`dist` argument should be a Sequence.") + check_grid_shape_preconditions(shape, dist, comm_size) distdims = tuple(i for (i, v) in enumerate(dist) if v != 'n') ndistdim = len(distdims) if ndistdim == 0: dist_grid_shape = () + elif ndistdim == 1: # Trivial case: all processes used for the one distributed dimension. - dist_grid_shape = (comm_size,) + if comm_size >= shape[distdims[0]]: + dist_grid_shape = (shape[distdims[0]],) + else: + dist_grid_shape = (comm_size,) elif comm_size == 1: # Trivial case: only one process to distribute over! 
@@ -116,7 +168,9 @@ def make_grid_shape(shape, dist, comm_size): for distdim in distdims: grid_shape[distdim] = next(it) - return tuple(grid_shape) + out_grid_shape = tuple(grid_shape) + check_grid_shape_postconditions(out_grid_shape, shape, dist, comm_size) + return out_grid_shape def _compute_grid_ratios(shape):