diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 0cfe79be..0e806e28 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -17,10 +17,12 @@ import operator from itertools import product from functools import reduce +from collections import Sequence import numpy as np import distarray +from distarray.metadata_utils import sanitize_indices from distarray.dist.maps import Distribution from distarray.utils import _raise_nie from distarray.metadata_utils import normalize_reduction_axes @@ -32,7 +34,6 @@ # Code # --------------------------------------------------------------------------- - class DistArray(object): __array_priority__ = 20.0 @@ -84,7 +85,8 @@ def get_dim_datas_and_dtype(arr): # has context, get dist and dtype elif (distribution is None) and (dtype is None): - res = context.apply(get_dim_datas_and_dtype, args=(key,)) + res = context.apply(get_dim_datas_and_dtype, args=(key,), + targets=targets) dim_datas = [i[0] for i in res] dtypes = [i[1] for i in res] da._dtype = dtypes[0] @@ -95,7 +97,8 @@ def get_dim_datas_and_dtype(arr): # has context and dtype, get dist elif (distribution is None) and (dtype is not None): da._dtype = dtype - dim_datas = context.apply(getattr, args=(key, 'dim_data')) + dim_datas = context.apply(getattr, args=(key, 'dim_data'), + targets=targets) da.distribution = Distribution.from_dim_data_per_rank(context, dim_datas, targets) @@ -128,10 +131,31 @@ def __repr__(self): (self.shape, self.targets) return s + def _process_return_value(self, result, return_proxy, index, targets): + + if return_proxy: + # proxy returned as result of slice + # slicing shouldn't alter the dtype + result = result[0] + return DistArray.from_localarrays(key=result, + context=self.context, + targets=targets, + dtype=self.dtype) + + elif isinstance(result, Sequence): + somethings = [i for i in result if i is not None] + if len(somethings) == 0: + # using checked_getitem and all return None + raise IndexError("Index %r is is not present." % (index,)) + if len(somethings) == 1: + return somethings[0] + else: + return result + else: + assert False # impossible is nothing + + def __getitem__(self, index): - #TODO: FIXME: major performance improvements possible here, - # especially for special cases like `index == slice(None)`. - # This would dramatically improve tondarray's performance. # to be run locally def checked_getitem(arr, index): @@ -141,30 +165,34 @@ def checked_getitem(arr, index): def raw_getitem(arr, index): return arr.global_index[index] - if isinstance(index, int) or isinstance(index, slice): - tuple_index = (index,) - return self.__getitem__(tuple_index) + # to be run locally + def get_slice(arr, index, ddpr, comm): + from distarray.local.maps import Distribution + local_distribution = Distribution(comm=comm, + dim_data=ddpr[comm.Get_rank()]) + result = arr.global_index.get_slice(index, local_distribution) + return proxyize(result) + + return_type, index = sanitize_indices(index, ndim=self.ndim, + shape=self.shape) + return_proxy = (return_type == 'view') + targets = self.distribution.owning_targets(index) + + args = [self.key, index] + if self.distribution.has_precise_index: + if return_proxy: # returning a new DistArray view + new_distribution = self.distribution.slice(index) + ddpr = new_distribution.get_dim_data_per_rank() + args.extend([ddpr, new_distribution.comm]) + local_fn = get_slice + else: # returning a value + local_fn = raw_getitem + else: # returning a value from unstructured + local_fn = checked_getitem + + result = self.context.apply(local_fn, args=args, targets=targets) + return self._process_return_value(result, return_proxy, index, targets) - elif isinstance(index, tuple): - targets = self.distribution.owning_targets(index) - - args = (self.key, index) - if self.distribution.has_precise_index: - result = self.context.apply(raw_getitem, args=args, - targets=targets) - else: - result = self.context.apply(checked_getitem, args=args, - targets=targets) - result = [i for i in result if i is not None] - if len(result) != 1: - raise IndexError("Getting more than one result (%s) is not " - "supported yet." % (result,)) - elif result is None: - raise IndexError("Index %r is out of bounds" % (index,)) - else: - return result[0] - else: - raise TypeError("Invalid index type.") def __setitem__(self, index, value): #TODO: FIXME: major performance improvements possible here. @@ -181,26 +209,21 @@ def checked_setitem(arr, index, value): def raw_setitem(arr, index, value): arr.global_index[index] = value - if isinstance(index, int) or isinstance(index, slice): - tuple_index = (index,) - return self.__setitem__(tuple_index, value) + _, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) - elif isinstance(index, tuple): - targets = self.distribution.owning_targets(index) - args = (self.key, index, value) - if self.distribution.has_precise_index: - self.context.apply(raw_setitem, args=args, targets=targets) - else: - result = self.context.apply(checked_setitem, args=args, - targets=targets) - result = [i for i in result if i is not None] - if len(result) > 1: - raise IndexError("Setting more than one result (%s) is " - "not supported yet." % (result,)) - elif result == []: - raise IndexError("Index %s is out of bounds" % (index,)) + targets = self.distribution.owning_targets(index) + args = (self.key, index, value) + if self.distribution.has_precise_index: + self.context.apply(raw_setitem, args=args, targets=targets) else: - raise TypeError("Invalid index type.") + result = self.context.apply(checked_setitem, args=args, + targets=targets) + result = [i for i in result if i is not None] + if len(result) > 1: + raise IndexError("Setting more than one result (%s) is " + "not supported yet." % (result,)) + elif result == []: + raise IndexError("Index %s is out of bounds" % (index,)) @property def context(self): diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index 46195202..2e9a6747 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -26,6 +26,7 @@ import operator from itertools import product from abc import ABCMeta, abstractmethod +from numbers import Integral import numpy as np @@ -34,11 +35,12 @@ from distarray.utils import remove_elements from distarray.metadata_utils import (normalize_dist, normalize_grid_shape, - make_grid_shape, - positivify, - _start_stop_block, normalize_dim_dict, normalize_reduction_axes, + make_grid_shape, + sanitize_indices, + _start_stop_block, + tuple_intersection, shapes_from_dim_data_per_rank) @@ -137,13 +139,13 @@ class MapBase(object): dimension of a distributed array. Maps allow distributed arrays to keep track of which process to talk to when indexing and slicing. - Classes that inherit from `MapBase` must implement the `owners()` + Classes that inherit from `MapBase` must implement the `index_owners()` abstractmethod. """ @abstractmethod - def owners(self, idx): + def index_owners(self, idx): """ Returns a list of process IDs in this dimension that might possibly own `idx`. @@ -194,9 +196,12 @@ def __init__(self, size, grid_size): self.size = size self.grid_size = grid_size - def owners(self, idx): + def index_owners(self, idx): return [0] if 0 <= idx < self.size else [] + def slice_owners(self, idx): + return [0] # slicing doesn't complain about out-of-bounds indices + def get_dimdicts(self): return ({ 'dist_type': 'n', @@ -205,6 +210,19 @@ def get_dimdicts(self): 'proc_grid_rank': 0, },) + def slice(self, idx): + """Make a new Map from a slice.""" + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else self.size + intersection = tuple_intersection((0, self.size), (start, stop)) + if intersection: + intersection_size = intersection[1] - intersection[0] + else: + intersection_size = 0 + + return {'dist_type': self.dist, + 'size': intersection_size} + class BlockMap(MapBase): @@ -254,13 +272,25 @@ def __init__(self, size, grid_size): for grid_rank in range(grid_size)] self.boundary_padding = self.comm_padding = 0 - def owners(self, idx): + def index_owners(self, idx): coords = [] for (coord, (lower, upper)) in enumerate(self.bounds): if lower <= idx < upper: coords.append(coord) return coords + def slice_owners(self, idx): + coords = [] + if idx.step not in {None, 1}: + msg = "Slicing only implemented for step=1" + raise NotImplementedError(msg) + for (coord, (lower, upper)) in enumerate(self.bounds): + slice_tuple = (idx.start if idx.start is not None else 0, + idx.stop if idx.stop is not None else self.size) + if tuple_intersection((lower, upper), slice_tuple): + coords.append(coord) + return coords if coords != [] else [0] + def get_dimdicts(self): grid_ranks = range(len(self.bounds)) cpadding = self.comm_padding @@ -282,6 +312,22 @@ def get_dimdicts(self): }) return tuple(out) + def slice(self, idx): + """Make a new Map from a slice.""" + new_bounds = [0] + start = idx.start if idx.start is not None else 0 + # iterate over the processes in this dimension + for proc_start, proc_stop in self.bounds: + stop = idx.stop if idx.stop is not None else proc_stop + intersection = tuple_intersection((proc_start, proc_stop), + (start, stop)) + if intersection: + size = intersection[1] - intersection[0] + new_bounds.append(size + new_bounds[-1]) + + return {'dist_type': self.dist, + 'bounds': new_bounds} + class BlockCyclicMap(MapBase): @@ -317,7 +363,7 @@ def __init__(self, size, grid_size, block_size=1): self.grid_size = grid_size self.block_size = block_size - def owners(self, idx): + def index_owners(self, idx): idx_block = idx // self.block_size return [idx_block % self.grid_size] @@ -367,13 +413,13 @@ def __init__(self, size, grid_size, indices=None): if self.indices is not None: # Convert to NumPy arrays if not already. self.indices = [np.asarray(ind) for ind in self.indices] - self._owners = range(self.grid_size) + self._index_owners = range(self.grid_size) - def owners(self, idx): + def index_owners(self, idx): # TODO: FIXME: for now, the unstructured map just returns all # processes. Can be optimized if we know the upper and lower bounds # for each local array's global indices. - return self._owners + return self._index_owners def get_dimdicts(self): if self.indices is None: @@ -607,6 +653,24 @@ def has_precise_index(self): """ return not any(isinstance(m, UnstructuredMap) for m in self.maps) + def slice(self, index_tuple): + """Make a new Distribution from a slice.""" + new_targets = self.owning_targets(index_tuple) + global_dim_data = [] + # iterate over the dimensions + for map_, idx in zip(self.maps, index_tuple): + if isinstance(idx, Integral): + continue # integral indexing returns reduced dimensionality + elif isinstance(idx, slice): + global_dim_data.append(map_.slice(idx)) + else: + msg = "Index must be a sequence of Integrals and slices." + raise TypeError(msg) + + return self.__class__(context=self.context, + global_dim_data=global_dim_data, + targets=new_targets) + def owning_ranks(self, idxs): """ Returns a list of ranks that may *possibly* own the location in the `idxs` tuple. @@ -618,8 +682,15 @@ def owning_ranks(self, idxs): If the `idxs` tuple is out of bounds, raises `IndexError`. """ - idxs = map(positivify, idxs, self.shape) # positivify and check - dim_coord_hits = [m.owners(idx) for (m, idx) in zip(self.maps, idxs)] + _, idxs = sanitize_indices(idxs, ndim=self.ndim, shape=self.shape) + dim_coord_hits = [] + for m, idx in zip(self.maps, idxs): + if isinstance(idx, Integral): + owners = m.index_owners(idx) + elif isinstance(idx, slice): + owners = m.slice_owners(idx) + dim_coord_hits.append(owners) + all_coords = product(*dim_coord_hits) ranks = [self.rank_from_coords[c] for c in all_coords] return ranks diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 4ec083f7..2ffd7c9a 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -114,6 +114,64 @@ def test_global_tolocal_bug(self): numpy.testing.assert_array_equal(dap.tondarray(), ndarr) +class TestGetItemSlicing(ContextTestCase): + + def test_full_slice_block_dist(self): + size = 10 + expected = numpy.random.randint(11, size=size) + arr = self.context.fromarray(expected) + assert_array_equal(arr[:].toarray(), expected) + + def test_partial_slice_block_dist(self): + size = 10 + expected = numpy.random.randint(10, size=size) + arr = self.context.fromarray(expected) + assert_array_equal(arr[0:2].toarray(), expected[0:2]) + + def test_slice_a_slice_block_dist_0(self): + size = 10 + expected = numpy.random.randint(10, size=size) + arr = self.context.fromarray(expected) + s0 = arr[:9] + s1 = s0[0:5] + s2 = s1[:2] + assert_array_equal(s2.toarray(), expected[:2]) + + def test_slice_a_slice_block_dist_1(self): + size = 10 + expected = numpy.random.randint(10, size=size) + arr = self.context.fromarray(expected) + s0 = arr[:9] + s1 = s0[0:5] + s2 = s1[-2:] + assert_array_equal(s2.toarray(), expected[3:5]) + + def test_partial_slice_block_dist_2d(self): + shape = (10, 20) + expected = numpy.random.randint(10, size=shape) + arr = self.context.fromarray(expected) + assert_array_equal(arr[2:6, 3:10].toarray(), expected[2:6, 3:10]) + + def test_partial_negative_slice_block_dist_2d(self): + shape = (10, 20) + expected = numpy.random.randint(10, size=shape) + arr = self.context.fromarray(expected) + assert_array_equal(arr[-6:-2, -10:-3].toarray(), + expected[-6:-2, -10:-3]) + + def test_incomplete_slice_block_dist_2d(self): + shape = (10, 20) + expected = numpy.random.randint(10, size=shape) + arr = self.context.fromarray(expected) + assert_array_equal(arr[3:9].toarray(), expected[3:9]) + + def test_incomplete_index_block_dist_2d(self): + shape = (10, 20) + expected = numpy.random.randint(10, size=shape) + arr = self.context.fromarray(expected) + assert_array_equal(arr[1].toarray(), expected[1]) + + class TestDistArrayCreationFromGlobalDimData(ContextTestCase): def test_from_global_dim_data_irregular_block(self): diff --git a/distarray/dist/tests/test_maps.py b/distarray/dist/tests/test_maps.py index 9eec4dda..9af40fc7 100644 --- a/distarray/dist/tests/test_maps.py +++ b/distarray/dist/tests/test_maps.py @@ -10,7 +10,7 @@ from distarray.externals.six.moves import range from distarray.testing import ContextTestCase -from distarray.dist import maps as client_map +from distarray.dist import maps from distarray.dist.maps import MapBase @@ -18,10 +18,10 @@ class TestClientMap(ContextTestCase): def test_2D_bn(self): nrows, ncols = 31, 53 - cm = client_map.Distribution.from_shape(self.context, - (nrows, ncols), - {0: 'b'}, - (4, 1)) + cm = maps.Distribution.from_shape(self.context, + (nrows, ncols), + {0: 'b'}, + (4, 1)) chunksize = (nrows // 4) + 1 for _ in range(100): r, c = randrange(nrows), randrange(ncols) @@ -31,9 +31,10 @@ def test_2D_bn(self): def test_2D_bb(self): nrows, ncols = 3, 5 nprocs_per_dim = 2 - cm = client_map.Distribution.from_shape( - self.context, (nrows, ncols), ('b', 'b'), - (nprocs_per_dim, nprocs_per_dim)) + cm = maps.Distribution.from_shape(self.context, + (nrows, ncols), + ('b', 'b'), + (nprocs_per_dim, nprocs_per_dim)) row_chunks = nrows // nprocs_per_dim + 1 col_chunks = ncols // nprocs_per_dim + 1 for r in range(nrows): @@ -45,25 +46,28 @@ def test_2D_bb(self): def test_2D_cc(self): nrows, ncols = 3, 5 nprocs_per_dim = 2 - cm = client_map.Distribution.from_shape( - self.context, (nrows, ncols), ('c', 'c'), - (nprocs_per_dim, nprocs_per_dim)) + cm = maps.Distribution.from_shape(self.context, + (nrows, ncols), + ('c', 'c'), + (nprocs_per_dim, nprocs_per_dim)) for r in range(nrows): for c in range(ncols): - rank = (r % nprocs_per_dim) * nprocs_per_dim + (c % nprocs_per_dim) + rank = ((r % nprocs_per_dim) * nprocs_per_dim + + (c % nprocs_per_dim)) actual = cm.owning_ranks((r,c)) self.assertSequenceEqual(actual, [rank]) - def test_is_compatible(self): nr, nc, nd = 10**5, 10**6, 10**4 - cm0 = client_map.Distribution.from_shape( - self.context, (nr, nc, nd), ('b', 'c', 'n')) + cm0 = maps.Distribution.from_shape(self.context, + (nr, nc, nd), + ('b', 'c', 'n')) self.assertTrue(cm0.is_compatible(cm0)) - cm1 = client_map.Distribution.from_shape( - self.context, (nr, nc, nd), ('b', 'c', 'n')) + cm1 = maps.Distribution.from_shape(self.context, + (nr, nc, nd), + ('b', 'c', 'n')) self.assertTrue(cm1.is_compatible(cm1)) self.assertTrue(cm0.is_compatible(cm1)) @@ -71,8 +75,9 @@ def test_is_compatible(self): nr -= 1; nc -= 1; nd -= 1 - cm2 = client_map.Distribution.from_shape( - self.context, (nr, nc, nd), ('b', 'c', 'n')) + cm2 = maps.Distribution.from_shape(self.context, + (nr, nc, nd), + ('b', 'c', 'n')) self.assertFalse(cm1.is_compatible(cm2)) self.assertFalse(cm2.is_compatible(cm1)) @@ -80,9 +85,10 @@ def test_is_compatible(self): def test_reduce(self): nr, nc, nd = 10**5, 10**6, 10**4 - dist = client_map.Distribution.from_shape( - self.context, (nr, nc, nd), ('b', 'c', 'n'), - grid_shape=(2, 2, 1)) + dist = maps.Distribution.from_shape(self.context, + (nr, nc, nd), + ('b', 'c', 'n'), + grid_shape=(2, 2, 1)) new_dist0 = dist.reduce(axes=[0]) self.assertEqual(new_dist0.dist, ('c', 'n')) @@ -93,7 +99,8 @@ def test_reduce(self): new_dist1 = dist.reduce(axes=[1]) self.assertEqual(new_dist1.dist, ('b', 'n')) self.assertSequenceEqual(new_dist1.shape, (nr, nd)) - self.assertEqual(new_dist1.grid_shape, dist.grid_shape[:1]+dist.grid_shape[2:]) + self.assertEqual(new_dist1.grid_shape, + dist.grid_shape[:1] + dist.grid_shape[2:]) self.assertLess(set(new_dist1.targets), set(dist.targets)) new_dist2 = dist.reduce(axes=[2]) @@ -104,7 +111,7 @@ def test_reduce(self): def test_reduce_0D(self): N = 10**5 - dist = client_map.Distribution.from_shape(self.context, (N,)) + dist = maps.Distribution.from_shape(self.context, (N,)) new_dist = dist.reduce(axes=[0]) self.assertEqual(new_dist.dist, ()) self.assertSequenceEqual(new_dist.shape, ()) @@ -112,13 +119,72 @@ def test_reduce_0D(self): self.assertEqual(set(new_dist.targets), set(dist.targets[:1])) +class TestSlice(ContextTestCase): + + def test_from_partial_slice_1d(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15,)) + + s = (slice(0, 3),) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps), len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist) + self.assertSequenceEqual(d1.targets, [0]) + self.assertSequenceEqual(d1.shape, (3,)) + + def test_from_full_slice_1d(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15,)) + + s = (slice(None),) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps), len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist) + self.assertSequenceEqual(d1.targets, d0.targets) + self.assertSequenceEqual(d1.maps[0].bounds, d0.maps[0].bounds) + + def test_from_full_slice_2d(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20)) + + s = (slice(None), slice(None)) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps), len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist) + for m0, m1 in zip(d0.maps, d1.maps): + if m0.dist == 'b': + self.assertSequenceEqual(m0.bounds, m1.bounds) + self.assertSequenceEqual(d1.targets, d0.targets) + + def test_from_partial_slice_2d(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20)) + + s = (slice(3, 7), 4) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps)-1, len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist[:-1]) + for m, expected in zip(d1.maps, ([(0, 1), (1, 4)], [(0, 1)])): + self.assertSequenceEqual(m.bounds, expected) + + def test_full_slice_with_int_2d(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20)) + + s = (slice(None), 4) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps)-1, len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist[:-1]) + self.assertEqual(d1.shape, (15,)) + + class TestDunderMethods(ContextTestCase): @classmethod def setUpClass(cls): super(TestDunderMethods, cls).setUpClass() cls.shape = (3, 4, 5, 6) - cls.cm = client_map.Distribution.from_shape(cls.context, cls.shape) + cls.cm = maps.Distribution.from_shape(cls.context, cls.shape) def test___len__(self): self.assertEqual(len(self.cm), 4) @@ -135,9 +201,9 @@ def test___getitem__(self): class TestDistributionCreation(ContextTestCase): def test_all_n_dist(self): - distribution = client_map.Distribution.from_shape(self.context, - shape=(3, 3), - dist=('n', 'n')) + distribution = maps.Distribution.from_shape(self.context, + shape=(3, 3), + dist=('n', 'n')) self.context.ones(distribution) diff --git a/distarray/local/localarray.py b/distarray/local/localarray.py index b2fe829e..b8dc3ea2 100644 --- a/distarray/local/localarray.py +++ b/distarray/local/localarray.py @@ -11,37 +11,22 @@ # Imports # --------------------------------------------------------------------------- from collections import Mapping -from numbers import Integral import numpy as np from distarray.externals import six from distarray.externals.six.moves import zip -from distarray.local.mpiutils import MPI from distarray.utils import _raise_nie +from distarray.metadata_utils import sanitize_indices + +from distarray.local.mpiutils import MPI from distarray.local import format, maps from distarray.local.error import InvalidDimensionError, IncompatibleArrayError -# Register numpy integer types with numbers.Integral ABC. -Integral.register(np.signedinteger) -Integral.register(np.unsignedinteger) - - -def _sanitize_indices(indices): - if isinstance(indices, Integral) or isinstance(indices, slice): - return (indices,) - elif all(isinstance(i, Integral) or isinstance(i, slice) for i in indices): - return indices - else: - raise TypeError("Index must be a sequence of ints and slices") - - class GlobalIndex(object): - """Object which provides access to global indexing on - LocalArrays. - """ + """Object which provides access to global indexing on LocalArrays.""" def __init__(self, distribution, ndarray): self.distribution = distribution self.ndarray = ndarray @@ -65,16 +50,32 @@ def global_to_local(self, *global_ind): def local_to_global(self, *local_ind): return self.distribution.global_from_local(*local_ind) + def get_slice(self, global_inds, new_distribution): + try: + local_inds = self.global_to_local(*global_inds) + except KeyError as err: + raise IndexError(err) + view = self.ndarray[local_inds] + return LocalArray(distribution=new_distribution, + dtype=self.ndarray.dtype, + buf=view) + def __getitem__(self, global_inds): - global_inds = _sanitize_indices(global_inds) + return_type, global_inds = sanitize_indices(global_inds) + if return_type == 'view': + msg = "__getitem__ does not support slices. See `get_slice`." + raise TypeError(msg) + try: local_inds = self.global_to_local(*global_inds) - return self.ndarray[local_inds] except KeyError as err: raise IndexError(err) + return self.ndarray[local_inds] + + def __setitem__(self, global_inds, value): - global_inds = _sanitize_indices(global_inds) + _, global_inds = sanitize_indices(global_inds) try: local_inds = self.global_to_local(*global_inds) self.ndarray[local_inds] = value @@ -399,7 +400,14 @@ def __len__(self): def __getitem__(self, index): """Get a local item.""" - return self.ndarray[index] + return_type, index = sanitize_indices(index) + if return_type == 'value': + return self.ndarray[index] + elif return_type == 'view': + msg = "__getitem__ does not support slices. See `global_index.get_item`." + raise TypeError(msg) + else: + assert False # impossible is nothing def __setitem__(self, index, value): """Set a local item.""" diff --git a/distarray/local/maps.py b/distarray/local/maps.py index dc654037..e0c03be5 100644 --- a/distarray/local/maps.py +++ b/distarray/local/maps.py @@ -22,6 +22,7 @@ import operator from functools import reduce +from numbers import Integral import numpy as np from distarray.externals.six.moves import range, zip @@ -29,7 +30,11 @@ from distarray.local import construct from distarray.metadata_utils import (make_grid_shape, normalize_grid_shape, normalize_dist, distribute_indices, - positivify) + sanitize_indices) + +# Register numpy integer types with numbers.Integral ABC. +Integral.register(np.signedinteger) +Integral.register(np.unsignedinteger) class Distribution(object): @@ -134,14 +139,28 @@ def rank_from_coords(self, coords): def local_from_global(self, *global_ind): """ Given `global_ind` indices, translate into local indices.""" - global_ind = tuple(map(positivify, global_ind, self.global_shape)) - return tuple(self._maps[dim].local_from_global(global_ind[dim]) - for dim in range(self.ndim)) + _, idxs = sanitize_indices(global_ind, self.ndim, self.global_shape) + local_idxs = [] + for m, idx in zip(self._maps, global_ind): + if isinstance(idx, Integral): + local_idxs.append(m.local_from_global_index(idx)) + elif isinstance(idx, slice): + local_idxs.append(m.local_from_global_slice(idx)) + else: + raise TypeError("Index must be Integral or slice.") + return tuple(local_idxs) def global_from_local(self, *local_ind): """ Given `local_ind` indices, translate into global indices.""" - return tuple(self._maps[dim].global_from_local(local_ind[dim]) - for dim in range(self.ndim)) + global_idxs = [] + for m, idx in zip(self._maps, local_ind): + if isinstance(idx, Integral): + global_idxs.append(m.global_from_local_index(idx)) + elif isinstance(idx, slice): + global_idxs.append(m.global_from_local_slice(idx)) + else: + raise TypeError("Index must be Integral or slice.") + return tuple(global_idxs) def map_from_dim_dict(dd): @@ -199,16 +218,30 @@ def __init__(self, global_size, grid_size, grid_rank, start, stop): self.grid_size = grid_size self.grid_rank = grid_rank - def local_from_global(self, gidx): + def local_from_global_index(self, gidx): if gidx < self.start or gidx >= self.stop: raise IndexError("Global index %s out of bounds" % gidx) return gidx - self.start - def global_from_local(self, lidx): + def local_from_global_slice(self, gidx): + start = gidx.start if gidx.start is not None else 0 + stop = gidx.stop if gidx.stop is not None else self.global_size + new_start = max(start - self.start, 0) # prevent negative inds + new_stop = stop - self.start + return slice(new_start, new_stop) + + def global_from_local_index(self, lidx): if lidx >= self.local_size: raise IndexError("Local index %s out of bounds" % lidx) return lidx + self.start + def global_from_local_slice(self, lidx): + start = lidx.start if lidx.start is not None else 0 + stop = lidx.stop if lidx.stop is not None else self.global_size + new_start = start + self.start + new_stop = stop + self.start + return slice(new_start, new_stop) + @property def dim_dict(self): return {'dist_type': self.dist, @@ -248,13 +281,12 @@ def __init__(self, global_size, grid_size, grid_rank, start): self.local_size = (global_size - 1 - grid_rank) // grid_size + 1 self.global_size = global_size - - def local_from_global(self, gidx): + def local_from_global_index(self, gidx): if (gidx - self.start) % self.grid_size: raise IndexError("Global index %s out of bounds" % gidx) return (gidx - self.start) // self.grid_size - def global_from_local(self, lidx): + def global_from_local_index(self, lidx): if lidx >= self.local_size: raise IndexError("Local index %s out of bounds" % lidx) return (lidx * self.grid_size) + self.start @@ -299,14 +331,13 @@ def __init__(self, global_size, grid_size, grid_rank, start, block_size): self.local_size = local_nblocks * block_size + local_partial self.global_size = global_size - - def local_from_global(self, gidx): + def local_from_global_index(self, gidx): global_block, offset = divmod(gidx, self.block_size) if (global_block - self.start_block) % self.grid_size: raise IndexError("Global index %s out of bounds" % gidx) return self.block_size * ((global_block - self.start_block) // self.grid_size) + offset - def global_from_local(self, lidx): + def global_from_local_index(self, lidx): if lidx >= self.local_size: raise IndexError("Local index %s out of bounds" % lidx) local_block, offset = divmod(lidx, self.block_size) @@ -328,7 +359,7 @@ def global_iter(self): _global_index = np.empty((self.local_size,), dtype=np.int32) # FIXME: this is the slow way to do this... for i in range(self.local_size): - _global_index[i] = self.global_from_local(i) + _global_index[i] = self.global_from_local_index(i) return iter(_global_index) @property @@ -351,14 +382,14 @@ def __init__(self, global_size, grid_size, grid_rank, indices): local_indices = range(self.local_size) self._local_index = dict(zip(self.indices, local_indices)) - def local_from_global(self, gidx): + def local_from_global_index(self, gidx): try: lidx = self._local_index[gidx] except KeyError: raise IndexError("Global index %s out of bounds" % gidx) return lidx - def global_from_local(self, lidx): + def global_from_local_index(self, lidx): return self.indices[lidx] @property diff --git a/distarray/local/tests/paralleltest_localarray.py b/distarray/local/tests/paralleltest_localarray.py index fc40869f..4d5a24da 100644 --- a/distarray/local/tests/paralleltest_localarray.py +++ b/distarray/local/tests/paralleltest_localarray.py @@ -7,11 +7,12 @@ import unittest import numpy as np +from numpy.testing import assert_array_equal from distarray import utils from distarray.testing import (MpiTestCase, assert_localarrays_allclose, assert_localarrays_equal) -from distarray.local.localarray import LocalArray, ndenumerate +from distarray.local.localarray import LocalArray, ndenumerate, ones from distarray.local.maps import Distribution from distarray.local.error import InvalidDimensionError, IncompatibleArrayError @@ -340,6 +341,44 @@ def test_pack_unpack_index(self): self.assertEqual(global_inds, a.unpack_index(packed_ind)) +class TestSlicing(MpiTestCase): + + comm_size = 2 + + def test_slicing(self): + distribution = Distribution.from_shape(self.comm, + (16, 16), + dist=('b', 'n')) + a = ones(distribution) + if self.comm.Get_rank() == 0: + dd00 = {"dist_type": 'b', + "size": 5, + "start": 0, + "stop": 3, + "proc_grid_size": 2, + "proc_grid_rank": 0} + dd01 = {"dist_type": 'n', + "size": 16} + + new_distribution = Distribution(self.comm, [dd00, dd01]) + rvals = a.global_index.get_slice((slice(5, None), slice(None)), + new_distribution=new_distribution) + assert_array_equal(rvals, np.ones((3, 16))) + + elif self.comm.Get_rank() == 1: + dd10 = {"dist_type": 'b', + "size": 5, + "start": 3, + "stop": 5, + "proc_grid_size": 2, + "proc_grid_rank": 1} + dd11 = {"dist_type": 'n', + "size": 16} + new_distribution = Distribution(self.comm, [dd10, dd11]) + rvals = a.global_index.get_slice((slice(None, 10), slice(None)), + new_distribution=new_distribution) + assert_array_equal(rvals, np.ones((2, 16))) + class TestLocalArrayMethods(MpiTestCase): ddpr = [ diff --git a/distarray/local/tests/test_maps.py b/distarray/local/tests/test_maps.py index 0c19f283..6a3bebba 100644 --- a/distarray/local/tests/test_maps.py +++ b/distarray/local/tests/test_maps.py @@ -17,25 +17,25 @@ def setUp(self): dimdict = dict(dist_type='n', size=size) self.m = maps.map_from_dim_dict(dimdict) - def test_local_from_global(self): + def test_local_from_global_index(self): gis = range(0, 20) - lis = [self.m.local_from_global(gi) for gi in gis] + lis = [self.m.local_from_global_index(gi) for gi in gis] expected = list(range(20)) self.assertSequenceEqual(lis, expected) - def test_local_from_global_IndexError(self): + def test_local_from_global_index_IndexError(self): gi = 20 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) - def test_global_from_local(self): + def test_global_from_local_index(self): lis = range(20) - gis = [self.m.global_from_local(li) for li in lis] + gis = [self.m.global_from_local_index(li) for li in lis] expected = list(range(20)) self.assertSequenceEqual(gis, expected) - def test_global_from_local_IndexError(self): + def test_global_from_local_index_IndexError(self): li = 20 - self.assertRaises(IndexError, self.m.global_from_local, li) + self.assertRaises(IndexError, self.m.global_from_local_index, li) class TestBlockMap(unittest.TestCase): @@ -44,28 +44,28 @@ def setUp(self): dimdict = dict(dist_type='b', size=(39-16), start=16, stop=39) self.m = maps.map_from_dim_dict(dimdict) - def test_local_from_global(self): + def test_local_from_global_index(self): gis = range(16, 39) - lis = [self.m.local_from_global(gi) for gi in gis] + lis = [self.m.local_from_global_index(gi) for gi in gis] expected = list(range(23)) self.assertSequenceEqual(lis, expected) - def test_local_from_global_IndexError(self): + def test_local_from_global_index_IndexError(self): gi = 15 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) gi = 39 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) - def test_global_from_local(self): + def test_global_from_local_index(self): lis = range(23) - gis = [self.m.global_from_local(li) for li in lis] + gis = [self.m.global_from_local_index(li) for li in lis] expected = list(range(16, 39)) self.assertSequenceEqual(gis, expected) - def test_global_from_local_IndexError(self): + def test_global_from_local_index_IndexError(self): li = 25 - self.assertRaises(IndexError, self.m.global_from_local, li) + self.assertRaises(IndexError, self.m.global_from_local_index, li) class TestCyclicMap(unittest.TestCase): @@ -74,28 +74,28 @@ def setUp(self): dimdict = dict(dist_type='c', start=2, size=16, proc_grid_size=4, proc_grid_rank=2) self.m = maps.map_from_dim_dict(dimdict) - def test_local_from_global(self): + def test_local_from_global_index(self): gis = (2, 6, 10, 14) - lis = [self.m.local_from_global(gi) for gi in gis] + lis = [self.m.local_from_global_index(gi) for gi in gis] expected = tuple(range(4)) self.assertSequenceEqual(lis, expected) - def test_local_from_global_IndexError(self): + def test_local_from_global_index_IndexError(self): gi = 3 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) gi = 7 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) - def test_global_from_local(self): + def test_global_from_local_index(self): lis = range(4) - gis = [self.m.global_from_local(li) for li in lis] + gis = [self.m.global_from_local_index(li) for li in lis] expected = (2, 6, 10, 14) self.assertSequenceEqual(gis, expected) - def test_global_from_local_IndexError(self): + def test_global_from_local_index_IndexError(self): li = 5 - self.assertRaises(IndexError, self.m.global_from_local, li) + self.assertRaises(IndexError, self.m.global_from_local_index, li) class TestBlockCyclicMap(unittest.TestCase): @@ -105,28 +105,29 @@ def setUp(self): block_size=2) self.m = maps.map_from_dim_dict(dimdict) - def test_local_from_global(self): + def test_local_from_global_index(self): """Test the local_index method of BlockCyclicMap.""" gis = (2, 3, 10, 11) - lis = [self.m.local_from_global(gi) for gi in gis] + lis = [self.m.local_from_global_index(gi) for gi in gis] expected = tuple(range(4)) self.assertSequenceEqual(lis, expected) - def test_local_from_global_IndexError(self): + def test_local_from_global_index_IndexError(self): gi = 4 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) gi = 12 - self.assertRaises(IndexError, self.m.local_from_global, gi) + self.assertRaises(IndexError, self.m.local_from_global_index, gi) - def test_global_from_local(self): + def test_global_from_local_index(self): lis = range(4) - gis = [self.m.global_from_local(li) for li in lis] + gis = [self.m.global_from_local_index(li) for li in lis] expected = (2, 3, 10, 11) self.assertSequenceEqual(gis, expected) - def test_global_from_local_IndexError(self): + def test_global_from_local_index_IndexError(self): li = 5 - self.assertRaises(IndexError, self.m.global_from_local, li) + self.assertRaises(IndexError, self.m.global_from_local_index, li) + class TestMapEquivalences(unittest.TestCase): @@ -145,8 +146,8 @@ def test_compare_bcm_bm_local_index(self): [('dist_type', 'b'), ('stop', size // grid + start)])) - bcm_lis = [bcm.local_from_global(e) for e in range(4, 8)] - bm_lis = [bm.local_from_global(e) for e in range(4, 8)] + bcm_lis = [bcm.local_from_global_index(e) for e in range(4, 8)] + bm_lis = [bm.local_from_global_index(e) for e in range(4, 8)] self.assertSequenceEqual(bcm_lis, bm_lis) def test_compare_bcm_cm_local_index(self): @@ -161,8 +162,8 @@ def test_compare_bcm_cm_local_index(self): [('dist_type', 'c')])) cm = maps.map_from_dim_dict(dict(list(dimdict.items()) + [('dist_type', 'c')])) - bcm_lis = [bcm.local_from_global(e) for e in range(1, 16, 4)] - cm_lis = [cm.local_from_global(e) for e in range(1, 16, 4)] + bcm_lis = [bcm.local_from_global_index(e) for e in range(1, 16, 4)] + cm_lis = [cm.local_from_global_index(e) for e in range(1, 16, 4)] self.assertSequenceEqual(bcm_lis, cm_lis) diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index 1bf3d6e0..d63e15aa 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -7,13 +7,19 @@ import operator from itertools import product from functools import reduce +from numbers import Integral from collections import Sequence, Mapping import numpy from distarray import utils from distarray.externals.six import next -from distarray.externals.six.moves import map +from distarray.externals.six.moves import map, zip + + +# Register numpy integer types with numbers.Integral ABC. +Integral.register(numpy.signedinteger) +Integral.register(numpy.unsignedinteger) class InvalidGridShapeError(Exception): @@ -266,13 +272,112 @@ def normalize_dim_dict(dd): dd['proc_grid_rank'] = 0 -def positivify(index, size): - if 0 <= index < size: +def _positivify(index, size): + """Return a positive index offset from a Sequence's start.""" + if index is None or index >= 0: return index - elif -size <= index < 0: + elif index < 0: return size + index + +def _check_bounds(index, size): + """Check if an index is in bounds. + + Assumes a positive index as returned by _positivify. + """ + if not 0 <= index < size: + raise IndexError("Index %r out of bounds" % index) + + +def tuple_intersection(t1, t2): + """Compute intersection of two (start, stop) tuples. + + Parameters + ---------- + t1, t2 : 2-tuples + + Returns + ------- + 2-tuple or None + """ + stop = min(t1[1], t2[1]) + start = max(t1[0], t2[0]) + return (start, stop) if stop - start > 0 else None + + +def positivify(index, size): + """Check that an index is within bounds and return a positive version. + + Parameters + ---------- + index : Integral or slice + size : Integral + + Raises + ------ + IndexError + for out-of-bounds indices + NotImplementedError + for negative steps + """ + if isinstance(index, Integral): + index = _positivify(index, size) + _check_bounds(index, size) + return index + elif isinstance(index, slice): + start = _positivify(index.start, size) + stop = _positivify(index.stop, size) + # slice indexing doesn't check bounds + return slice(start, stop, index.step) + else: + raise TypeError("`index` must be of type Integral or slice.") + + +def sanitize_indices(indices, ndim=None, shape=None): + """Classify and sanitize `indices`. + + * Wrap Integral or slice indices into tuples + * Classify as 'value' or 'view' + * If the length of the tuple-ized `indices` is < ndim (and it's + provided), add slice(None)'s to indices until `indices` is ndim long + * If `shape` is provided, call `positivify` on the indices + + Raises + ------ + TypeError + If `indices` is other than Integral, slice or a Sequence of these + IndexError + If len(indices) > ndim + + Returns + ------- + 2-tuple of (str, ndim-tuple of slices and Integral values) + """ + if isinstance(indices, Integral): + rtype, sanitized = 'value', (indices,) + elif isinstance(indices, slice): + rtype, sanitized = 'view', (indices,) + elif all(isinstance(i, Integral) for i in indices): + rtype, sanitized = 'value', indices + elif all(isinstance(i, Integral) or isinstance(i, slice) for i in indices): + rtype, sanitized = 'view', indices else: - raise IndexError("Index %s out of bounds" % index) + msg = ("Index must be an Integral, a slice, or a sequence of " + "Integrals and slices.") + raise TypeError(msg) + + if ndim is not None: + diff = ndim - len(sanitized) + if diff < 0: + raise IndexError("Too many indices.") + if diff > 0: + # allow incomplete indexing + rtype = 'view' + sanitized = sanitized + (slice(None),) * diff + + if shape is not None: + sanitized = tuple(positivify(i, size) for (i, size) in zip(sanitized, + shape)) + return (rtype, sanitized) def normalize_reduction_axes(axes, ndim): diff --git a/distarray/tests/test_metadata_utils.py b/distarray/tests/test_metadata_utils.py index fe8760c2..1d944d54 100644 --- a/distarray/tests/test_metadata_utils.py +++ b/distarray/tests/test_metadata_utils.py @@ -26,6 +26,100 @@ def test_negative_index(self): result = metadata_utils.positivify(-2, 10) self.assertEqual(result, 8) + def test_out_of_bounds_positive(self): + with self.assertRaises(IndexError): + metadata_utils.positivify(11, 10) + + def test_out_of_bounds_negative(self): + with self.assertRaises(IndexError): + metadata_utils.positivify(-51, 10) + + def test_positive_slice(self): + s = slice(5, 7) + result = metadata_utils.positivify(s, 10) + self.assertEqual(result, s) + + def test_negative_slice_stop(self): + s = slice(5, -2) + result = metadata_utils.positivify(s, 10) + expected = slice(5, 8) + self.assertEqual(result, expected) + + def test_no_slice_start(self): + s = slice(5) + result = metadata_utils.positivify(s, 10) + expected = s + self.assertEqual(result, expected) + + def test_no_slice_stop(self): + s = slice(5, None) + result = metadata_utils.positivify(s, 10) + expected = s + self.assertEqual(result, expected) + + def test_positive_slice_with_step(self): + s = slice(5, 7, 2) + result = metadata_utils.positivify(s, 10) + expected = s + self.assertEqual(result, expected) + + def test_negative_slice_with_step(self): + s = slice(-7, -1, 2) + result = metadata_utils.positivify(s, 10) + expected = slice(3, 9, 2) + self.assertEqual(result, expected) + + def test_out_of_bounds_slice(self): + s = slice(50, 90) + result = metadata_utils.positivify(s, 10) + self.assertEqual(result, s) + + +class TestSanitizeIndices(unittest.TestCase): + + def test_value_index(self): + tag, sanitized = metadata_utils.sanitize_indices(10) + self.assertSequenceEqual(sanitized, (10,)) + self.assertEqual(tag, 'value') + + def test_slice_index(self): + tag, sanitized = metadata_utils.sanitize_indices(slice(10, 20)) + self.assertSequenceEqual(sanitized, (slice(10, 20),)) + self.assertEqual(tag, 'view') + + def test_tuple_of_values(self): + tag, sanitized = metadata_utils.sanitize_indices((5, 10)) + self.assertSequenceEqual(sanitized, (5, 10)) + self.assertEqual(tag, 'value') + + def test_tuple_of_slices(self): + slices = slice(10, 20), slice(20, 30), slice(40, 50) + tag, sanitized = metadata_utils.sanitize_indices(slices) + self.assertSequenceEqual(sanitized, slices) + self.assertEqual(tag, 'view') + + def test_tuple_of_mixed(self): + slices = slice(10, 20), 25, slice(40, 50) + tag, sanitized = metadata_utils.sanitize_indices(slices) + self.assertSequenceEqual(sanitized, slices) + self.assertEqual(tag, 'view') + + def test_incomplete_indexing_values(self): + slices = 10, 20, 25, 40, 50 + tag, sanitized = metadata_utils.sanitize_indices(slices, ndim=10) + self.assertSequenceEqual(sanitized, slices + (slice(None),) * 5) + self.assertEqual(tag, 'view') + + def test_incomplete_indexing_mixed(self): + slices = slice(10, 20), 25, slice(40, 50) + tag, sanitized = metadata_utils.sanitize_indices(slices, ndim=10) + self.assertSequenceEqual(sanitized, slices + (slice(None),) * 7) + self.assertEqual(tag, 'view') + + def test_too_many_indices(self): + with self.assertRaises(IndexError): + metadata_utils.sanitize_indices((2, 3, 4), ndim=2) + class TestGridSizes(unittest.TestCase): diff --git a/distarray/tests/test_utils.py b/distarray/tests/test_utils.py index c6308699..7dd37b1b 100644 --- a/distarray/tests/test_utils.py +++ b/distarray/tests/test_utils.py @@ -31,24 +31,6 @@ def test_mult_partitions(self): self.assertEqual(utils.mult_partitions(6, 3), [(1, 1, 6), (1, 2, 3)]) -class TestSanitizeIndices(unittest.TestCase): - - def test_point(self): - itype, inds = utils.sanitize_indices(1) - self.assertEqual(itype, 'point') - self.assertEqual(inds, (1,)) - - def test_slice(self): - itype, inds = utils.sanitize_indices(slice(1,10)) - self.assertEqual(itype, 'view') - self.assertEqual(inds, (slice(1,10),)) - - def test_mixed(self): - provided = (5, 3, slice(7, 10, 2), 99, slice(1,10)) - itype, inds = utils.sanitize_indices(provided) - self.assertEqual(itype, 'view') - self.assertEqual(inds, provided) - class TestSliceIntersection(unittest.TestCase): @@ -105,5 +87,7 @@ def test_count_round_trips(self): view.execute('42') self.assertEqual(r.count, len(view)) + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/distarray/utils.py b/distarray/utils.py index c631fdaf..8bfa6c07 100644 --- a/distarray/utils.py +++ b/distarray/utils.py @@ -109,37 +109,6 @@ def _raise_nie(): raise NotImplementedError(msg) -def sanitize_indices(indices): - """Check and possibly sanitize indices. - - Parameters - ---------- - indices : int, slice, or sequence of ints and slices - If an int or slice is passed in, it is converted to a - 1-tuple. - - Returns - ------- - 2-tuple - ('point', indices) if all `indices` are ints, or - ('view', indices) if some `indices` are slices. - - Raises - ------ - TypeError - If `indices` is not all ints or slices. - """ - - if isinstance(indices, int) or isinstance(indices, slice): - return sanitize_indices((indices,)) - elif all(isinstance(i, int) for i in indices): - return 'point', indices - elif all(isinstance(i, int) or isinstance(i, slice) for i in indices): - return 'view', indices - else: - raise TypeError("Index must be a sequence of ints and slices") - - def slice_intersection(s1, s2): """Compute a slice that represents the intersection of two slices.