Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions distarray/dist/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def func_call(func_name, arr_name, args, kwargs):
res = func(arr_name, *args, **kwargs)
return proxyize(res), res.dtype # noqa

res = context.apply(func_call, args=(name, a.key, args, kwargs))
res = context.apply(func_call, args=(name, a.key, args, kwargs),
targets=a.targets)
new_key = res[0][0]
dtype = res[0][1]
return DistArray.from_localarrays(new_key,
Expand Down Expand Up @@ -84,7 +85,8 @@ def func_call(func_name, a, b, args, kwargs):
res = func(a, b, *args, **kwargs)
return proxyize(res), res.dtype # noqa

res = context.apply(func_call, args=(name, a_key, b_key, args, kwargs))
res = context.apply(func_call, args=(name, a_key, b_key, args, kwargs),
targets=distribution.targets)
new_key = res[0][0]
dtype = res[0][1]
return DistArray.from_localarrays(new_key,
Expand Down
18 changes: 11 additions & 7 deletions distarray/dist/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank, targets=None):
self.ndim = len(dd0)
self.dist = tuple(dd['dist_type'] for dd in dd0)
self.grid_shape = tuple(dd['proc_grid_size'] for dd in dd0)
self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim,
self.grid_shape = normalize_grid_shape(self.grid_shape, self.shape,
self.dist, len(self.targets))

coords = [tuple(d['proc_grid_rank'] for d in dd) for dd in
Expand Down Expand Up @@ -453,8 +453,6 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None,

self = cls.__new__(cls)
self.context = context
self.targets = sorted(targets or context.targets)
self.comm = self.context._make_subcomm(self.targets)
self.shape = shape
self.ndim = len(shape)

Expand All @@ -463,13 +461,19 @@ def from_shape(cls, context, shape, dist=None, grid_shape=None,
dist = {0: 'b'}
self.dist = normalize_dist(dist, self.ndim)

# all possible targets
all_targets = sorted(targets or context.targets)
# grid_shape
if grid_shape is None:
grid_shape = make_grid_shape(self.shape, self.dist,
len(self.targets))
len(all_targets))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the order of operations should be changed here. We should first get a fully validated and normalized self.grid_shape, then get self.targets, then get self.comm.

Would this work?

if grid_shape is None:
    grid_shape = make_grid_shape(...)
self.grid_shape = normalize_grid_shape(grid_shape, self.shape, self.dist, len(all_targets))
self.targets = [...] # as before
self.comm = self.context._make_subcomm(self.targets)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


self.grid_shape = normalize_grid_shape(grid_shape, self.ndim,
self.dist, len(self.targets))
self.grid_shape = normalize_grid_shape(grid_shape, self.shape,
self.dist, len(all_targets))
ntargets = reduce(operator.mul, self.grid_shape, 1)
# choose targets from grid_shape
self.targets = all_targets[:ntargets]
self.comm = self.context._make_subcomm(self.targets)

# TODO: FIXME: assert that self.rank_from_coords is valid and conforms
# to how MPI does it.
Expand Down Expand Up @@ -579,7 +583,7 @@ def __init__(self, context, global_dim_data, targets=None):
self.dist = tuple(m.dist for m in self.maps)
self.grid_shape = tuple(m.grid_size for m in self.maps)

self.grid_shape = normalize_grid_shape(self.grid_shape, self.ndim,
self.grid_shape = normalize_grid_shape(self.grid_shape, self.shape,
self.dist, len(self.targets))

nelts = reduce(operator.mul, self.grid_shape, 1)
Expand Down
8 changes: 4 additions & 4 deletions distarray/dist/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def rand(self, distribution):
da_key = self.context._generate_key()
ddpr = distribution.get_dim_data_per_rank()
ddpr_name = self.context._key_and_push(ddpr)[0]
comm_name = self.context.comm
comm_name = distribution.comm
self.context._execute(
'{da_key} = distarray.local.random.rand('
'distribution=distarray.local.maps.Distribution('
Expand Down Expand Up @@ -122,7 +122,7 @@ def normal(self, distribution, loc=0.0, scale=1.0):
ddpr = distribution.get_dim_data_per_rank()
loc_name, scale_name, ddpr_name = \
self.context._key_and_push(loc, scale, ddpr)
comm_name = self.context.comm
comm_name = distribution.comm
self.context._execute(
'{da_key} = distarray.local.random.normal('
'loc={loc_name}, scale={scale_name},'
Expand Down Expand Up @@ -160,7 +160,7 @@ def randint(self, distribution, low, high=None):
ddpr = distribution.get_dim_data_per_rank()
low_name, high_name, ddpr_name = \
self.context._key_and_push(low, high, ddpr)
comm_name = self.context.comm
comm_name = distribution.comm
self.context._execute(
'{da_key} = distarray.local.random.randint('
'low={low_name}, high={high_name},'
Expand All @@ -186,7 +186,7 @@ def randn(self, distribution):
da_key = self.context._generate_key()
ddpr = distribution.get_dim_data_per_rank()
ddpr_name = self.context._key_and_push(ddpr)[0]
comm_name = self.context.comm
comm_name = distribution.comm
self.context._execute(
'{da_key} = distarray.local.random.randn('
'distribution=distarray.local.maps.Distribution('
Expand Down
18 changes: 6 additions & 12 deletions distarray/dist/tests/test_distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,16 @@ def test_mean_multiaxis(self):
da_mean = self.darr.mean(axis=(0, 1))
assert_allclose(da_mean.tondarray(), np_mean)

def test_mean_along_axis_1(self):
def test_mean_along_axis_0(self):
da_mean = self.darr.mean(axis=0)
np_mean = self.arr.mean(axis=0)
assert_allclose(da_mean.tondarray(), np_mean)

def test_mean_along_axis_1(self):
da_mean = self.darr.mean(axis=1)
np_mean = self.arr.mean(axis=1)
assert_allclose(da_mean.tondarray(), np_mean)

def test_mean_dtype(self):
da_mean = self.darr.mean(axis=0, dtype=int)
np_mean = self.arr.mean(axis=0, dtype=int)
Expand Down Expand Up @@ -531,17 +536,6 @@ def test_sum_4D_cyclic(self):
assert_allclose(darr_sum.tondarray(), arr_sum)
assert_allclose(darr.sum().tondarray(), arr.sum())

def test_empty_localarray(self):
if len(self.context.targets) < 2:
raise self.skipTest("not enough targets to run test.")
dist = Distribution.from_shape(self.context,
shape=(1,),
dist=('b',),
targets=self.context.targets[:2])
darr = self.context.ones(dist)
self.assertRaises(NotImplementedError, darr.min, ())
self.assertRaises(NotImplementedError, darr.sum, (), {'axis':0})


class TestFromLocalArrays(ContextTestCase):

Expand Down
2 changes: 1 addition & 1 deletion distarray/local/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def from_shape(cls, comm, shape, dist=None, grid_shape=None):

if grid_shape is None: # Make a new grid_shape if not provided.
grid_shape = make_grid_shape(shape, dist_tuple, comm_size)
grid_shape = normalize_grid_shape(grid_shape, ndim,
grid_shape = normalize_grid_shape(grid_shape, shape,
dist_tuple, comm_size)

comm = construct.init_comm(base_comm, grid_shape)
Expand Down
2 changes: 1 addition & 1 deletion distarray/local/tests/paralleltest_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_basic_2d(self):

def test_bad_distribution(self):
"""Test that invalid distribution type fails as expected."""
with self.assertRaises(TypeError):
with self.assertRaises(ValueError):
# Invalid distribution type 'x'.
Distribution.from_shape(comm=self.comm, shape=(7,),
dist={0: 'x'}, grid_shape=(4,))
Expand Down
80 changes: 67 additions & 13 deletions distarray/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,72 @@ class GridShapeError(Exception):
pass


def normalize_grid_shape(grid_shape, ndims, dist, comm_size):
def check_grid_shape_preconditions(shape, dist, comm_size):
"""
Verify various distarray parameters are correct before making a grid_shape.
"""
if comm_size < 1:
raise ValueError("comm_size >= 1 not satisfied, comm_size = %s" %
(comm_size,))
if len(shape) != len(dist):
raise ValueError("len(shape) == len(dist) not satisfied, len(shape) ="
" %s and len(dist) = %s" % (len(shape), len(dist)))
if any(i < 0 for i in shape):
raise ValueError("shape must be a sequence of non-negative integers, "
"shape = %s" % (shape,))
if any(i not in ('b', 'c', 'n', 'u') for i in dist):
raise ValueError("dist must be a sequence of 'b', 'n', 'c', 'u' "
"strings, dist = %s" % (dist,))


def check_grid_shape_postconditions(grid_shape, shape, dist, comm_size):
if not (len(grid_shape) == len(shape) == len(dist)):
raise ValueError("len(gird_shape) == len(shape) == len(dist) not "
"satisfied, len(grid_shape) = %s and len(shape) = %s "
"and len(dist) = %s" % (len(grid_shape), len(shape),
len(dist)))
if any(gs < 1 for gs in grid_shape):
raise ValueError("all(gs >= 1 for gs in grid_shape) not satisfied, "
"grid_shape = %s" % (grid_shape,))
if any(gs != 1 for (d, gs) in zip(dist, grid_shape) if d == 'n'):
raise ValueError("all(gs == 1 for (d, gs) in zip(dist, grid_shape) if "
"d == 'n', not satified dist = %s and grid_shape = "
"%s" % (dist, grid_shape))
if any(gs > s for (s, gs) in zip(shape, grid_shape) if s > 0):
raise ValueError("all(gs <= s for (s, gs) in zip(shape, grid_shape) "
"if s > 0) not satisfied, shape = %s and grid_shape "
"= %s" % (shape, grid_shape))
if reduce(operator.mul, grid_shape, 1) > comm_size:
raise ValueError("reduce(operator.mul, grid_shape, 1) <= comm_size not"
" satisfied, grid_shape = %s product = %s and "
"comm_size = %s" % (
grid_shape,
reduce(operator.mul, grid_shape, 1),
comm_size))


def normalize_grid_shape(grid_shape, shape, dist, comm_size):
"""Adds 1s to grid_shape so it has `ndims` dimensions. Validates
`grid_shape` tuple against the `dist` tuple and `comm_size`.
"""
def check_normalization_preconditions(grid_shape, dist):
if any(i < 0 for i in grid_shape):
raise ValueError("grid_shape must be a sequence of non-negative "
"integers, grid_shape = %s" % (grid_shape,))
if len(grid_shape) > len(dist):
raise ValueError("len(grid_shape) <= len(dist) not satisfied, "
"len(grid_shape) = %s and len(dist) = %s" %
(len(grid_shape), len(dist)))
check_grid_shape_preconditions(shape, dist, comm_size)
check_normalization_preconditions(grid_shape, dist)

ndims = len(shape)
grid_shape = tuple(grid_shape) + (1,) * (ndims - len(grid_shape))

# short circuit for special case
if all(x == 'n' for x in dist):
if not all(x == 1 for x in grid_shape):
raise ValueError("grid shape should be all `1`'s not %s." %
grid_shape)
return grid_shape

if len(grid_shape) != len(dist):
msg = "grid_shape's length (%d) not equal to dist's length (%d)"
raise InvalidGridShapeError(msg % (len(grid_shape), len(dist)))
if reduce(operator.mul, grid_shape, 1) != comm_size:
if reduce(operator.mul, grid_shape, 1) > comm_size:
msg = "grid shape %r not compatible with comm size of %d."
raise InvalidGridShapeError(msg % (grid_shape, comm_size))
return grid_shape
Expand Down Expand Up @@ -73,16 +122,19 @@ def make_grid_shape(shape, dist, comm_size):
if not possible to distribute `comm_size` processes over number of
dimensions.
"""
if not isinstance(dist, Sequence):
raise TypeError("`dist` argument should be a Sequence.")
check_grid_shape_preconditions(shape, dist, comm_size)
distdims = tuple(i for (i, v) in enumerate(dist) if v != 'n')
ndistdim = len(distdims)

if ndistdim == 0:
dist_grid_shape = ()

elif ndistdim == 1:
# Trivial case: all processes used for the one distributed dimension.
dist_grid_shape = (comm_size,)
if comm_size >= shape[distdims[0]]:
dist_grid_shape = (shape[distdims[0]],)
else:
dist_grid_shape = (comm_size,)

elif comm_size == 1:
# Trivial case: only one process to distribute over!
Expand Down Expand Up @@ -116,7 +168,9 @@ def make_grid_shape(shape, dist, comm_size):
for distdim in distdims:
grid_shape[distdim] = next(it)

return tuple(grid_shape)
out_grid_shape = tuple(grid_shape)
check_grid_shape_postconditions(out_grid_shape, shape, dist, comm_size)
return out_grid_shape


def _compute_grid_ratios(shape):
Expand Down