diff --git a/Makefile b/Makefile index e1777dc2..1f20f4b2 100644 --- a/Makefile +++ b/Makefile @@ -60,13 +60,13 @@ install: # Testing-related targets. # ---------------------------------------------------------------------------- -test_client: +test_ipython: ${PYTHON} -m unittest discover -c -.PHONY: test_client +.PHONY: test_ipython -test_client_with_coverage: +test_ipython_with_coverage: ${COVERAGE} run -pm unittest discover -cv -.PHONY: test_client_with_coverage +.PHONY: test_ipython_with_coverage ${PARALLEL_OUT_DIR} : mkdir ${PARALLEL_OUT_DIR} @@ -95,10 +95,10 @@ test_mpi_with_coverage: ${MPI_ONLY_LAUNCH_TEST} .PHONY: test_mpi_with_coverage -test: test_client test_engines test_mpi +test: test_ipython test_mpi test_engines .PHONY: test -test_with_coverage: test_client_with_coverage test_engines_with_coverage test_mpi_with_coverage +test_with_coverage: test_ipython_with_coverage test_mpi_with_coverage test_engines_with_coverage .PHONY: test_with_coverage coverage_report: diff --git a/distarray/globalapi/context.py b/distarray/globalapi/context.py index c593fe5d..619c70e1 100644 --- a/distarray/globalapi/context.py +++ b/distarray/globalapi/context.py @@ -772,6 +772,7 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize): """ from importlib import import_module import types + from distarray.metadata_utils import arg_kwarg_proxy_converter from distarray.localapi import LocalArray main = import_module('__main__') @@ -793,19 +794,8 @@ def func_wrapper(func, apply_nonce, context_key, args, kwargs, autoproxyize): func = types.FunctionType(func_code, new_func_globals, func_name, func_defaults, func_closure) - # convert args - args = list(args) - for i, a in enumerate(args): - if isinstance(a, main.Proxy): - args[i] = a.dereference() - args = tuple(args) - - # convert kwargs - for k in kwargs.keys(): - val = kwargs[k] - if isinstance(val, main.Proxy): - kwargs[k] = val.dereference() + args, kwargs = arg_kwarg_proxy_converter(args, kwargs) result = func(*args, **kwargs) if autoproxyize and isinstance(result, LocalArray): diff --git a/distarray/globalapi/distarray.py b/distarray/globalapi/distarray.py index 517ca133..35d456bd 100644 --- a/distarray/globalapi/distarray.py +++ b/distarray/globalapi/distarray.py @@ -24,7 +24,7 @@ import distarray.localapi from distarray.metadata_utils import sanitize_indices -from distarray.globalapi.maps import Distribution +from distarray.globalapi.maps import Distribution, asdistribution from distarray.utils import _raise_nie from distarray.metadata_utils import normalize_reduction_axes @@ -495,6 +495,65 @@ def local_view(larr, ddpr, dtype): return DistArray.from_localarrays(key=new_key, distribution=new_dist, dtype=dtype) + def distribute_as(self, shape_or_dist): + """ + Redistributes this DistArray, returning a new DistArray with the same + data and corresponding distribution. + + Parameters + ---------- + shape_or_dist : shape tuple or Distribution object. + Distribution for the new DistArray. The new distribution must have + the same number of items as this distarray. The global shape and + targets may be different. If shape tuple, immediately converted to + a Distribution object with default parameters. + + Returns + ------- + DistArray + A new DistArray distributed according to `dist`. + + Note + ---- + Currently implemented for block and non-distributed maps only. + + """ + + dist = asdistribution(self.context, shape_or_dist) + + if (any(d not in ('b', 'n') for d in self.distribution.dist) or + any(d not in ('b', 'n') for d in dist.dist)): + msg = "Only block and non-distributed dimensions currently supported." + raise NotImplementedError(msg) + + def _local_redistribute_same_shape(comm, plan, la_from, la_to): + from distarray.localapi import redistribute + redistribute(comm, plan, la_from, la_to) + + def _local_redistribute_general(comm, plan, la_from, la_to): + from distarray.localapi import redistribute_general + redistribute_general(comm, plan, la_from, la_to) + + source_size = self.global_size + dest_size = reduce(operator.mul, dist.shape, 1) + + if self.distribution.shape == dist.shape: + _local_redistribute = _local_redistribute_same_shape + elif source_size == dest_size: + _local_redistribute = _local_redistribute_general + else: + msg = ("Original size %d != new size %d," + " and total size of new array must be unchanged.") + raise ValueError(msg % (source_size, dest_size)) + + plan = self.distribution.get_redist_plan(dist) + ubercomm, all_targets = self.distribution.comm_union(dist) + result = DistArray(dist, dtype=self.dtype) + + self.context.apply(_local_redistribute, (ubercomm, plan, self.key, result.key), + targets=all_targets) + return result + # Binary operators def _binary_op_from_ufunc(self, other, func, rop_str=None, *args, **kwargs): diff --git a/distarray/globalapi/maps.py b/distarray/globalapi/maps.py index 69f644b3..b0b237d3 100644 --- a/distarray/globalapi/maps.py +++ b/distarray/globalapi/maps.py @@ -41,7 +41,9 @@ sanitize_indices, _start_stop_block, tuple_intersection, - shapes_from_dim_data_per_rank) + shapes_from_dim_data_per_rank, + condense, + strides_from_shape) def _dedup_dim_dicts(dim_dicts): @@ -551,7 +553,7 @@ def from_maps(cls, context, maps, targets=None): self = super(Distribution, cls).__new__(cls) self.context = context self.targets = sorted(targets or context.targets) - self.comm = self.context.make_subcomm(self.targets) + self._comm = None self.maps = maps self.shape = tuple(m.size for m in self.maps) self.ndim = len(self.maps) @@ -758,6 +760,12 @@ def __getitem__(self, idx): def __len__(self): return len(self.maps) + @property + def comm(self): + if self._comm is None: + self._comm = self.context.make_subcomm(self.targets) + return self._comm + @property def has_precise_index(self): """ @@ -869,3 +877,140 @@ def view(self, new_dimsize=None): def localshapes(self): return shapes_from_dim_data_per_rank(self.get_dim_data_per_rank()) + + def comm_union(self, *dists): + """ + Make a communicator that includes the union of all targets in `dists`. + + Parameters + ---------- + dists: sequence of distribution objects. + + Returns + ------- + tuple + First element is encompassing communicator proxy; second is a + sequence of all targets in `dists`. + + """ + dist_targets = [d.targets for d in dists] + all_targets = sorted(reduce(set.union, dist_targets, set(self.targets))) + return self.context.make_subcomm(all_targets), all_targets + + # ------------------------------------------------------------------------ + # Redistribution + # ------------------------------------------------------------------------ + + @staticmethod + def _redist_intersection_same_shape(source_dimdata, dest_dimdata): + + intersections = [] + for source_dimdict, dest_dimdict in zip(source_dimdata, dest_dimdata): + + if not (source_dimdict['dist_type'] == + dest_dimdict['dist_type'] == 'b'): + raise ValueError("Only 'b' dist_type supported") + + source_idxs = source_dimdict['start'], source_dimdict['stop'] + dest_idxs = dest_dimdict['start'], dest_dimdict['stop'] + + intersections.append(tuple_intersection(source_idxs, dest_idxs)) + + return intersections + + @staticmethod + def _redist_intersection_reshape(source_dimdata, dest_dimdata): + source_flat = global_flat_indices(source_dimdata) + dest_flat = global_flat_indices(dest_dimdata) + return _global_flat_indices_intersection(source_flat, dest_flat) + + def get_redist_plan(self, other_dist): + # Get all targets + all_targets = sorted(set(self.targets + other_dist.targets)) + union_rank_from_target = {t: r for (r, t) in enumerate(all_targets)} + + source_ranks = range(len(self.targets)) + source_targets = self.targets + union_rank_from_source_rank = {sr: union_rank_from_target[st] + for (sr, st) in + zip(source_ranks, source_targets)} + + dest_ranks = range(len(other_dist.targets)) + dest_targets = other_dist.targets + union_rank_from_dest_rank = {sr: union_rank_from_target[st] + for (sr, st) in + zip(dest_ranks, dest_targets)} + + source_ddpr = self.get_dim_data_per_rank() + dest_ddpr = other_dist.get_dim_data_per_rank() + source_dest_pairs = product(source_ddpr, dest_ddpr) + + if self.shape == other_dist.shape: + _intersection = Distribution._redist_intersection_same_shape + else: + _intersection = Distribution._redist_intersection_reshape + + plan = [] + for source_dd, dest_dd in source_dest_pairs: + intersections = _intersection(source_dd, dest_dd) + if intersections and all(i for i in intersections): + source_coords = tuple(dd['proc_grid_rank'] for dd in source_dd) + source_rank = self.rank_from_coords[source_coords] + dest_coords = tuple(dd['proc_grid_rank'] for dd in dest_dd) + dest_rank = other_dist.rank_from_coords[dest_coords] + plan.append({ + 'source_rank': union_rank_from_source_rank[source_rank], + 'dest_rank': union_rank_from_dest_rank[dest_rank], + 'indices': intersections, + } + ) + + return plan + + +# ---------------------------------------------------------------------------- +# Redistribution helper functions. +# ---------------------------------------------------------------------------- + +def global_flat_indices(dim_data): + """ + Return a list of tuples of indices into the flattened global array. + + Parameters + ---------- + dim_data: dimension dictionary. + + Returns + ------- + list of 2-tuples of ints. + Each tuple is a (start, stop) interval into the flattened global array. + All selected ranges comprise the indices for this dim_data's sub-array. + + """ + # TODO: FIXME: can be optimized when the last dimension is 'n'. + + for dd in dim_data: + if dd['dist_type'] == 'n': + dd['start'] = 0 + dd['stop'] = dd['size'] + + glb_shape = tuple(dd['size'] for dd in dim_data) + glb_strides = strides_from_shape(glb_shape) + + ranges = [range(dd['start'], dd['stop']) for dd in dim_data[:-1]] + start_ranges = ranges + [[dim_data[-1]['start']]] + stop_ranges = ranges + [[dim_data[-1]['stop']]] + + def flatten(idx): + return sum(a * b for (a, b) in zip(idx, glb_strides)) + + starts = map(flatten, product(*start_ranges)) + stops = map(flatten, product(*stop_ranges)) + + intervals = zip(starts, stops) + return condense(intervals) + +def _global_flat_indices_intersection(gfis0, gfis1): + intersections = filter(None, [tuple_intersection(a, b) + for (a, b) in product(gfis0, gfis1)]) + return [i[:2] for i in intersections] diff --git a/distarray/globalapi/tests/test_distarray.py b/distarray/globalapi/tests/test_distarray.py index 4d68427c..3bacdd70 100644 --- a/distarray/globalapi/tests/test_distarray.py +++ b/distarray/globalapi/tests/test_distarray.py @@ -20,7 +20,7 @@ from distarray.externals.six.moves import range from distarray.testing import DefaultContextTestCase from distarray.globalapi.distarray import DistArray -from distarray.globalapi.maps import Distribution +from distarray.globalapi.maps import Distribution, global_flat_indices class TestDistArray(DefaultContextTestCase): @@ -1015,5 +1015,192 @@ def test_incompatible_dtype(self): da.view(dtype=dtype) +class TestBlockRedistribution(DefaultContextTestCase): + + def test_redist_identity(self): + source_dist = dest_dist = Distribution(self.context, (10, 10), + ('b', 'b'), (1, 1), targets=[0]) + source_da = self.context.empty(source_dist, dtype=numpy.int32) + source_da.fill(-42) + dest_da = source_da.distribute_as(dest_dist) + assert_array_equal(source_da.tondarray(), dest_da.tondarray()) + + def test_redist_1D(self): + dist0 = Distribution(self.context, (40,), ('b',), (2,), targets=[1,3]) + dist1 = Distribution(self.context, (40,), ('b',), (2,), targets=[0,2]) + da = self.context.ones(dist0) + db = da.distribute_as(dist1) + + self.assertIs(db.distribution, dist1) + self.assertSequenceEqual(db.localshapes(), da.localshapes()) + assert_array_equal(da.tondarray(), db.tondarray()) + + dist2 = Distribution(self.context, (40,), ('b',), (2,), targets=[0,2]) + + da.fill(-42) + dc = da.distribute_as(dist2) + self.assertIs(dc.distribution, dist2) + self.assertSequenceEqual(dc.localshapes(), da.localshapes()) + assert_array_equal(da.tondarray(), dc.tondarray()) + + def test_redist_2D(self): + nrows, ncols = 7, 13 + source_dist = Distribution(self.context, (nrows, ncols), + ('b', 'b'), (2, 2), targets=range(4)) + dest_gdd = ( + { + 'dist_type': 'b', + 'bounds': [0, nrows//3, nrows], + }, + { + 'dist_type': 'b', + 'bounds': [0, ncols//3, ncols], + } + ) + dest_dist = Distribution.from_global_dim_data(self.context, + dest_gdd, targets=range(4)) + source_da = self.context.empty(source_dist, dtype=numpy.int32) + source_da.fill(-42) + + dest_da = source_da.distribute_as(dest_dist) + assert_array_equal(source_da.tondarray(), dest_da.tondarray()) + + def test_redist_incompatible_sizes(self): + source_da = self.context.empty((10,), dtype=numpy.int32) + + with self.assertRaises(ValueError): + source_da.distribute_as(Distribution(self.context, (9,))) + + with self.assertRaises(ValueError): + source_da.distribute_as(Distribution(self.context, (3, 4))) + + with self.assertRaises(ValueError): + source_da.distribute_as(Distribution(self.context, (3, 4000))) + + def test_redist_unsupported_dist_types(self): + source_dist = Distribution(self.context, (10, 20), ('n', 'c')) + source_da = self.context.empty(source_dist) + + with self.assertRaises(NotImplementedError): + source_da.distribute_as(source_da.distribution) + + source_da = self.context.empty((10, 20), ('b', 'b')) + dest_dist = Distribution(self.context, (10, 20), ('b', 'c')) + + with self.assertRaises(NotImplementedError): + source_da.distribute_as(dest_dist) + + +class TestReshapeRedistribution(DefaultContextTestCase): + + def test_global_flat_indices(self): + dist0 = Distribution(self.context, (40,), ('b',), (2,), targets=[1,3]) + + self.assertSequenceEqual([[(0, 20)], [(20, 40)]], + [global_flat_indices(ddpr) for ddpr in dist0.get_dim_data_per_rank()]) + + dist1 = Distribution(self.context, (5, 8), ('b', 'n'), (2,), targets=[0,2]) + + self.assertSequenceEqual([[(0, 24)], [(24, 40)]], + [global_flat_indices(ddpr) for ddpr in dist1.get_dim_data_per_rank()]) + + dist2 = Distribution(self.context, (5, 8), ('b', 'b'), (2, 2), targets=[0,1,2,3]) + + self.assertSequenceEqual([[(0, 4), (8, 12), (16, 20)], + [(4, 8), (12, 16), (20, 24)], + [(24, 28), (32, 36)], + [(28, 32), (36, 40)]], + [global_flat_indices(ddpr) for ddpr in dist2.get_dim_data_per_rank()]) + + def test_redist_reshape_same_target(self): + dist0 = Distribution(self.context, (40,), ('b',), (1,), targets=[1]) + dist1 = Distribution(self.context, (5, 8), ('b', 'n'), (1,), targets=[1]) + + da_src = self.context.empty(dist0) + da_src.fill(-10) + da_dest = da_src.distribute_as(dist1) + + expected = numpy.empty((5, 8)) + expected.fill(-10) + + assert_array_equal(da_dest.tondarray(), expected) + + def test_redist_reshape_diff_target(self): + dist0 = Distribution(self.context, (40,), ('b',), (1,), targets=[0]) + dist1 = Distribution(self.context, (5, 8), ('b', 'n'), (1,), targets=[1]) + + da_src = self.context.empty(dist0) + da_src.fill(-10) + da_dest = da_src.distribute_as(dist1) + + expected = numpy.empty((5, 8)) + expected.fill(-10) + + assert_array_equal(da_dest.tondarray(), expected) + + def test_redist_reshape_split_targets(self): + dist0 = Distribution(self.context, (40,), ('b',), (1,), targets=[0]) + dist1 = Distribution(self.context, (5, 8), ('b', 'n'), (2,), targets=[0,1]) + + da_src = self.context.empty(dist0) + da_src.fill(-10) + da_dest = da_src.distribute_as(dist1) + + expected = numpy.empty((5, 8)) + expected.fill(-10) + + assert_array_equal(da_dest.tondarray(), expected) + + def test_redist_reshape_disjoint_targets(self): + dist0 = Distribution(self.context, (40,), ('b',), (2,), targets=[1,3]) + dist1 = Distribution(self.context, (5, 8), ('b', 'n'), (2,), targets=[0,2]) + + da_src = self.context.empty(dist0) + da_src.fill(-10) + da_dest = da_src.distribute_as(dist1) + + expected = numpy.empty((5, 8)) + expected.fill(-10) + + assert_array_equal(da_dest.tondarray(), expected) + + def test_redist_reshape_three_dee(self): + three_dee_shape = (3, 5, 2) + two_dee_shape = (3 * 2, 5) + dist0 = Distribution(self.context, three_dee_shape, ('b', 'b', 'n'), (2, 2)) + dist1 = Distribution(self.context, two_dee_shape, ('b', 'b'), (2, 2)) + + da_src = self.context.empty(dist0) + da_src.fill(47) + + da_dest = da_src.distribute_as(dist1) + + self.assertTrue(da_dest.distribution.is_compatible(dist1)) + + expected = numpy.empty(two_dee_shape) + expected.fill(47) + + assert_array_equal(da_dest.tondarray(), expected) + + def test_redist_reshape_big(self): + three_dee_shape = (13, 17, 19) + two_dee_shape = (13, 19 * 17) + + dist0 = Distribution(self.context, three_dee_shape, ('b', 'n', 'b'), (2, 1, 2)) + dist1 = Distribution(self.context, two_dee_shape, ('b', 'b'), (2, 2)) + + da_src = self.context.empty(dist0) + da_src.fill(47) + + da_dest = da_src.distribute_as(dist1) + + self.assertTrue(da_dest.distribution.is_compatible(dist1)) + + expected = numpy.empty(two_dee_shape) + expected.fill(47) + + assert_array_equal(da_dest.tondarray(), expected) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/distarray/globalapi/tests/test_maps.py b/distarray/globalapi/tests/test_maps.py index 74e8b08a..71b7423f 100644 --- a/distarray/globalapi/tests/test_maps.py +++ b/distarray/globalapi/tests/test_maps.py @@ -299,6 +299,96 @@ def test_all_n_dist(self): dist=('n', 'n')) self.context.ones(distribution) +class TestRedistribution(DefaultContextTestCase): + + def test_block_redistribution_one_to_one(self): + source_dist = Distribution(self.context, (40,), + ('b',), (2,), targets=[1, 3]) + dest_dist = Distribution(self.context, (40,), + ('b',), (2,), targets=[0, 2]) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 1, 'dest_rank': 0, 'indices': [(0, 20, 1)]}, + {'source_rank': 3, 'dest_rank': 2, 'indices': [(20, 40, 1)]}, + ] + self.assertEqual(plan, expected) + + def test_block_redist_one_to_many(self): + source_dist = Distribution(self.context, (40,), + ('b',), (1,), targets=[1]) + dest_dist = Distribution(self.context, (40,), + ('b',), (2,), targets=[0,2]) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 1, 'dest_rank': 0, 'indices': [(0, 20, 1)]}, + {'source_rank': 1, 'dest_rank': 2, 'indices': [(20, 40, 1)]}, + ] + self.assertEqual(plan, expected) + + def test_block_redist_many_to_one(self): + source_dist = Distribution(self.context, (40,), + ('b',), (2,), targets=[1, 2]) + dest_dist = Distribution(self.context, (40,), + ('b',), (1,), targets=[0]) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 1, 'dest_rank': 0, 'indices': [(0, 20, 1)]}, + {'source_rank': 2, 'dest_rank': 0, 'indices': [(20, 40, 1)]}, + ] + self.assertEqual(plan, expected) + + def test_block_redist_identity(self): + source_dist = dest_dist = Distribution(self.context, (40,), + ('b',), (4,)) + plan_a = source_dist.get_redist_plan(dest_dist) + plan_b = dest_dist.get_redist_plan(source_dist) + self.assertSequenceEqual(plan_a, plan_b) + for p, lshape in zip(plan_a, source_dist.localshapes()): + self.assertEqual(p['source_rank'], p['dest_rank']) + start, stop, step = p['indices'][0] + self.assertEqual(step, 1) + self.assertEqual(stop - start, lshape[0]) + + def test_block_redist_2D_identity(self): + source_dist = dest_dist = Distribution(self.context, (10, 10), + ('b', 'b'), (1, 1), + targets=[0]) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 0, 'dest_rank': 0, 'indices': [(0, 10, 1), (0, 10, 1)]}, + ] + self.assertEqual(plan, expected) + + def test_block_redist_2D_many_to_one(self): + source_dist = Distribution(self.context, (9, 9), + ('b', 'b'), (2, 2), targets=range(4)) + dest_dist = Distribution(self.context, (9, 9), + ('b', 'b'), (1, 1), targets=[2]) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 0, 'dest_rank': 2, 'indices': [(0, 5, 1), (0, 5, 1)]}, + {'source_rank': 1, 'dest_rank': 2, 'indices': [(0, 5, 1), (5, 9, 1)]}, + {'source_rank': 2, 'dest_rank': 2, 'indices': [(5, 9, 1), (0, 5, 1)]}, + {'source_rank': 3, 'dest_rank': 2, 'indices': [(5, 9, 1), (5, 9, 1)]}, + ] + for p, e in zip(plan, expected): + self.assertEqual(p, e) + + def test_block_redist_2D_one_to_many(self): + source_dist = Distribution(self.context, (9, 9), + ('b', 'b'), (1, 1), targets=[2]) + dest_dist = Distribution(self.context, (9, 9), + ('b', 'b'), (2, 2), targets=range(4)) + plan = source_dist.get_redist_plan(dest_dist) + expected = [ + {'source_rank': 2, 'dest_rank': 0, 'indices': [(0, 5, 1), (0, 5, 1)]}, + {'source_rank': 2, 'dest_rank': 1, 'indices': [(0, 5, 1), (5, 9, 1)]}, + {'source_rank': 2, 'dest_rank': 2, 'indices': [(5, 9, 1), (0, 5, 1)]}, + {'source_rank': 2, 'dest_rank': 3, 'indices': [(5, 9, 1), (5, 9, 1)]}, + ] + for p, e in zip(plan, expected): + self.assertEqual(p, e) + class TestNoEmptyLocals(DefaultContextTestCase): diff --git a/distarray/localapi/localarray.py b/distarray/localapi/localarray.py index d1a6937e..432f86e8 100644 --- a/distarray/localapi/localarray.py +++ b/distarray/localapi/localarray.py @@ -13,6 +13,9 @@ from __future__ import print_function, division +def mpi_print(*args, **kwargs): + from distarray.mpionly_utils import get_world_rank + print("[%d]" % get_world_rank(), *args, **kwargs) # --------------------------------------------------------------------------- # Imports @@ -22,14 +25,80 @@ import numpy as np from distarray.externals import six -from distarray.externals.six.moves import zip +from distarray.externals.six.moves import zip, reduce -from distarray.metadata_utils import sanitize_indices +from distarray.metadata_utils import sanitize_indices, strides_from_shape, ndim_from_flat, condense from distarray.localapi.mpiutils import MPI from distarray.localapi import format, maps from distarray.localapi.error import InvalidDimensionError, IncompatibleArrayError +# ---------------------------------------------------------------------------- +# Local Redistribution functions +# ---------------------------------------------------------------------------- + +def _transform(local_distribution, glb_flat): + glb_strides = strides_from_shape(local_distribution.global_shape) + local_strides = strides_from_shape(local_distribution.local_shape) + glb_ndim_inds = ndim_from_flat(glb_flat, glb_strides) + local_ind = local_distribution.local_from_global(glb_ndim_inds) + local_flat = local_distribution.local_flat_from_local(local_ind) + return local_flat + +def _massage_indices(local_distribution, glb_intervals): + # XXX: TODO: document why we do `-1)+1` below. + local_flat_slices = [(_transform(local_distribution, i[0]), + _transform(local_distribution, i[1]-1)+1) + for i in glb_intervals] + return condense(local_flat_slices) + +def _mpi_dtype_from_intervals(larr, glb_intervals): + local_intervals = _massage_indices(larr.distribution, glb_intervals) + blocklengths = [stop-start for (start, stop) in local_intervals] + displacements = [start for (start, _) in local_intervals] + mpidtype = MPI.__TypeDict__[np.sctype2char(larr.dtype)] + newtype = mpidtype.Create_indexed(blocklengths, displacements) + newtype.Commit() + return newtype + +def redistribute_general(comm, plan, la_from, la_to): + myrank = comm.Get_rank() + for dta in plan: + if dta['source_rank'] == dta['dest_rank'] == myrank: + from_dtype = _mpi_dtype_from_intervals(la_from, dta['indices']) + to_dtype = _mpi_dtype_from_intervals(la_to, dta['indices']) + comm.Sendrecv(sendbuf=[la_from.ndarray, 1, from_dtype], dest=myrank, + recvbuf=[la_to.ndarray, 1, to_dtype], source=myrank) + elif dta['source_rank'] == myrank: + from_dtype = _mpi_dtype_from_intervals(la_from, dta['indices']) + comm.Send([la_from.ndarray, 1, from_dtype], dest=dta['dest_rank']) + elif dta['dest_rank'] == myrank: + to_dtype = _mpi_dtype_from_intervals(la_to, dta['indices']) + comm.Recv([la_to.ndarray, 1, to_dtype], source=dta['source_rank']) + +def make_local_slices(local_arr, glb_indices): + slices = tuple(slice(*inds) for inds in glb_indices) + return local_arr.local_from_global(slices) + +def redistribute(comm, plan, la_from, la_to): + myrank = comm.Get_rank() + for dta in plan: + if dta['source_rank'] == dta['dest_rank'] == myrank: + # simple local copy from `la_from` to `la_to` + slices_from = make_local_slices(la_from, dta['indices']) + slices_to = make_local_slices(la_to, dta['indices']) + la_to.ndarray[slices_to] = la_from.ndarray[slices_from] + elif dta['source_rank'] == myrank: + source_slices = make_local_slices(la_from, dta['indices']) + sliced_ndarr = la_from.ndarray[source_slices] + sliced_buffer = sliced_ndarr.ravel() + comm.Send(sliced_buffer, dest=dta['dest_rank']) + elif dta['dest_rank'] == myrank: + dest_slices = make_local_slices(la_to, dta['indices']) + sliced_ndarr = la_to.ndarray[dest_slices] + recv_buffer = np.empty_like(sliced_ndarr) + comm.Recv(recv_buffer, source=dta['source_rank']) + sliced_ndarr[...] = recv_buffer class GlobalIndex(object): """Object which provides access to global indexing on LocalArrays.""" diff --git a/distarray/localapi/maps.py b/distarray/localapi/maps.py index 8e7f8eb7..093800af 100644 --- a/distarray/localapi/maps.py +++ b/distarray/localapi/maps.py @@ -25,7 +25,7 @@ from numbers import Integral import numpy as np -from distarray.externals.six.moves import range, zip +from distarray.externals.six.moves import range, zip, reduce from distarray.localapi import construct from distarray.metadata_utils import (make_grid_shape, normalize_grid_shape, @@ -169,6 +169,12 @@ def global_from_local(self, local_ind): raise TypeError("Index must be Integral or slice.") return tuple(global_idxs) + def local_flat_from_local(self, local_ind): + local_strides = _get_strides(self.local_shape) + def flatten(idx, strides): + return sum(a * b for (a, b) in zip(idx, strides)) + return flatten(local_ind, local_strides) + def map_from_dim_dict(dd): """ Factory function that returns a 1D map for a given dimension @@ -204,6 +210,12 @@ def map_from_dim_dict(dd): raise ValueError("Unsupported dist_type of %r" % dist_type) +def _accum(start, next): + return tuple(s * next for s in start) + (next,) + +def _get_strides(shape): + return reduce(_accum, tuple(shape[1:]) + (1,), ()) + class MapBase(object): """ Base class for all one dimensional Map classes. diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index 02185efe..e865b631 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -529,3 +529,71 @@ def shapes_from_dim_data_per_rank(ddpr): # ddpr = dim_data_per_rank shape.append(size_from_dim_data(dd)) shape_list.append(tuple(shape)) return shape_list + +# ---------------------------------------------------------------------------- +# Redistribution-related utilities. +# ---------------------------------------------------------------------------- + +def _accum(start, next): + return tuple(s * next for s in start) + (next,) + +def strides_from_shape(shape): + return reduce(_accum, tuple(shape[1:]) + (1,), ()) + +def ndim_from_flat(flat, strides): + res = [] + for st in strides: + res.append(flat // st) + flat %= st + return tuple(res) + +def _squeeze(accum, next): + last = accum[-1] + if not last: + return [next] + elif last[-1] != next[0]: + return accum + [next] + elif last[-1] == next[0]: + return accum[:-1] + [(last[0], next[-1])] + +def condense(intervals): + intervals = reduce(_squeeze, intervals, [[]]) + return intervals + +# ---------------------------------------------------------------------------- +# `apply` related utilities. +# ---------------------------------------------------------------------------- + +def arg_kwarg_proxy_converter(args, kwargs, module_name='__main__'): + from importlib import import_module + + module = import_module(module_name) + # convert args + + # In some situations, like redistributing a DistArray from one set of + # targets to a disjoint set, the source and destination DistArrays (and + # associated LocalArrays) are in different communicators with different + # targets. In those cases, it is possible for a proxy object for one + # DistArray to not refer to anything on this target. In that case, + # `a.dereference()` raises an `AttributeError`. We intercept that here and + # assign `None` instead. + + args = list(args) + for i, a in enumerate(args): + if isinstance(a, module.Proxy): + try: + args[i] = a.dereference() + except AttributeError: + args[i] = None + args = tuple(args) + + # convert kwargs + for k in kwargs.keys(): + val = kwargs[k] + if isinstance(val, module.Proxy): + try: + kwargs[k] = val.dereference() + except AttributeError: + kwargs[k] = None + + return args, kwargs diff --git a/distarray/mpi_engine.py b/distarray/mpi_engine.py index 1bba45c5..5193f40c 100644 --- a/distarray/mpi_engine.py +++ b/distarray/mpi_engine.py @@ -11,6 +11,7 @@ from importlib import import_module import types +from distarray.metadata_utils import arg_kwarg_proxy_converter from distarray.localapi import LocalArray from distarray.localapi.proxyize import Proxy @@ -42,22 +43,6 @@ def __init__(self): break Engine.INTERCOMM.Free() - def arg_kwarg_proxy_converter(self, args, kwargs): - module = import_module('__main__') - # convert args - args = list(args) - for i, a in enumerate(args): - if isinstance(a, module.Proxy): - args[i] = a.dereference() - args = tuple(args) - - # convert kwargs - for k in kwargs.keys(): - val = kwargs[k] - if isinstance(val, module.Proxy): - kwargs[k] = val.dereference() - - return args, kwargs def is_engine(self): if self.world.rank != self.client_rank: @@ -103,7 +88,7 @@ def func_call(self, msg): module = import_module('__main__') module.proxyize.set_state(nonce) - args, kwargs = self.arg_kwarg_proxy_converter(args, kwargs) + args, kwargs = arg_kwarg_proxy_converter(args, kwargs) new_func_globals = module.__dict__ # add proper proxyize, context_key new_func_globals.update({'proxyize': module.proxyize, @@ -153,7 +138,7 @@ def builtin_call(self, msg): args = msg[2] kwargs = msg[3] - args, kwargs = self.arg_kwarg_proxy_converter(args, kwargs) + args, kwargs = arg_kwarg_proxy_converter(args, kwargs) res = func(*args, **kwargs) Engine.INTERCOMM.send(res, dest=self.client_rank)