diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index 2e9a6747..c636a087 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -21,7 +21,7 @@ `UnstructuredMap`. """ -from __future__ import absolute_import +from __future__ import division, absolute_import import operator from itertools import product @@ -200,7 +200,13 @@ def index_owners(self, idx): return [0] if 0 <= idx < self.size else [] def slice_owners(self, idx): - return [0] # slicing doesn't complain about out-of-bounds indices + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else self.size + step = idx.step if idx.step is not None else 1 + if tuple_intersection((start, stop, step), (0, self.size)): + return [0] + else: + return [] def get_dimdicts(self): return ({ @@ -214,14 +220,16 @@ def slice(self, idx): """Make a new Map from a slice.""" start = idx.start if idx.start is not None else 0 stop = idx.stop if idx.stop is not None else self.size - intersection = tuple_intersection((0, self.size), (start, stop)) - if intersection: - intersection_size = intersection[1] - intersection[0] + step = idx.step if idx.step is not None else 1 + isection = tuple_intersection((start, stop, step), (0, self.size)) + if isection: + step = idx.step if idx.step is not None else 1 + isection_size = int(np.ceil((isection[1] - isection[0]) / step)) else: - intersection_size = 0 + isection_size = 0 return {'dist_type': self.dist, - 'size': intersection_size} + 'size': isection_size} class BlockMap(MapBase): @@ -281,15 +289,13 @@ def index_owners(self, idx): def slice_owners(self, idx): coords = [] - if idx.step not in {None, 1}: - msg = "Slicing only implemented for step=1" - raise NotImplementedError(msg) + start = idx.start if idx.start is not None else 0 + stop = idx.stop if idx.stop is not None else self.size + step = idx.step if idx.step is not None else 1 for (coord, (lower, upper)) in enumerate(self.bounds): - slice_tuple = (idx.start if idx.start is not None else 0, - idx.stop if idx.stop is not None else self.size) - if tuple_intersection((lower, upper), slice_tuple): + if tuple_intersection((start, stop, step), (lower, upper)): coords.append(coord) - return coords if coords != [] else [0] + return coords def get_dimdicts(self): grid_ranks = range(len(self.bounds)) @@ -316,14 +322,17 @@ def slice(self, idx): """Make a new Map from a slice.""" new_bounds = [0] start = idx.start if idx.start is not None else 0 + step = idx.step if idx.step is not None else 1 # iterate over the processes in this dimension for proc_start, proc_stop in self.bounds: stop = idx.stop if idx.stop is not None else proc_stop - intersection = tuple_intersection((proc_start, proc_stop), - (start, stop)) - if intersection: - size = intersection[1] - intersection[0] - new_bounds.append(size + new_bounds[-1]) + isection = tuple_intersection((start, stop, step), + (proc_start, proc_stop)) + if isection: + isection_size = int(np.ceil((isection[1] - (isection[0])) / step)) + new_bounds.append(isection_size + new_bounds[-1]) + if len(new_bounds) == [0]: + new_bounds = [] return {'dist_type': self.dist, 'bounds': new_bounds} diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 2f587e30..928566c3 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -146,6 +146,13 @@ def test_slice_a_slice_block_dist_1(self): s2 = s1[-2:] assert_array_equal(s2.toarray(), expected[3:5]) + def test_slice_block_dist_1d_with_step(self): + size = 10 + step = 2 + expected = numpy.random.randint(10, size=size) + darr = self.context.fromarray(expected) + assert_array_equal(darr[::step].toarray(), expected[::step]) + def test_partial_slice_block_dist_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) @@ -233,6 +240,15 @@ def test_large_1d_slice(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_1d_slice_with_step(self): + source = numpy.random.randint(10, size=20) + new_data = numpy.random.randint(10, size=5) + slc = slice(7, 17, 2) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + def test_2d_slice_0(self): # on process boundaries source = numpy.random.randint(10, size=(10, 20)) @@ -243,6 +259,15 @@ def test_2d_slice_0(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_2d_slice_with_step(self): + source = numpy.random.randint(10, size=(10, 20)) + new_data = numpy.random.randint(10, size=(2, 5)) + slc = (slice(5, 10, 3), slice(5, 15, 2)) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + def test_2d_slice_1(self): # off process boundaries source = numpy.random.randint(10, size=(10, 20)) @@ -271,6 +296,15 @@ def test_full_3d_slice_ellipsis(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_3d_slice_ellipsis_with_step(self): + source = numpy.random.randint(10, size=(5, 4, 5)) + new_data = numpy.random.randint(10, size=(5, 2, 5)) + slc = (Ellipsis, slice(None, None, 2), Ellipsis) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + def test_partial_indexing_0(self): source = numpy.random.randint(10, size=(3, 4, 5)) new_data = numpy.random.randint(10, size=(4, 5)) diff --git a/distarray/dist/tests/test_maps.py b/distarray/dist/tests/test_maps.py index 9af40fc7..9ebd8392 100644 --- a/distarray/dist/tests/test_maps.py +++ b/distarray/dist/tests/test_maps.py @@ -4,6 +4,8 @@ # Distributed under the terms of the BSD License. See COPYING.rst. # --------------------------------------------------------------------------- +from __future__ import division + import unittest from random import randrange @@ -143,6 +145,29 @@ def test_from_full_slice_1d(self): self.assertSequenceEqual(d1.targets, d0.targets) self.assertSequenceEqual(d1.maps[0].bounds, d0.maps[0].bounds) + def test_from_full_slice_with_step_1d_0(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(15,)) + + s = (slice(None, None, 2),) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps), len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist) + self.assertSequenceEqual(d1.targets, d0.targets) + self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0]) + + def test_from_full_slice_with_step_1d_1(self): + d0 = maps.Distribution.from_shape(context=self.context, shape=(30,)) + step = 4 + + s = (slice(4, None, step),) + d1 = d0.slice(s) + + self.assertEqual(len(d0.maps), len(d1.maps)) + self.assertSequenceEqual(d1.dist, d0.dist) + self.assertSequenceEqual(d1.targets, d0.targets) + self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0]) + def test_from_full_slice_2d(self): d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20)) diff --git a/distarray/local/maps.py b/distarray/local/maps.py index e0c03be5..2c187912 100644 --- a/distarray/local/maps.py +++ b/distarray/local/maps.py @@ -224,11 +224,19 @@ def local_from_global_index(self, gidx): return gidx - self.start def local_from_global_slice(self, gidx): + # we don't make the effort to compute the exact slice + # `__getitem__` doesn't care about overly-large slices, we just + # have to get the offset from the start correct based on the `step` start = gidx.start if gidx.start is not None else 0 stop = gidx.stop if gidx.stop is not None else self.global_size - new_start = max(start - self.start, 0) # prevent negative inds + step = gidx.step if gidx.step is not None else 1 + new_start = start - self.start + if new_start < 0: # don't allow negative starts + new_start += step * abs(new_start // step) + if new_start < 0: + new_start += step new_stop = stop - self.start - return slice(new_start, new_stop) + return slice(new_start, new_stop, gidx.step) def global_from_local_index(self, lidx): if lidx >= self.local_size: diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index 945e5310..2c704ade 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -4,6 +4,8 @@ # Distributed under the terms of the BSD License. See COPYING.rst. # --------------------------------------------------------------------------- +from __future__ import division + import operator from itertools import product from functools import reduce @@ -288,20 +290,42 @@ def _check_bounds(index, size): raise IndexError("Index %r out of bounds" % index) -def tuple_intersection(t1, t2): - """Compute intersection of two (start, stop) tuples. +def tuple_intersection(t0, t1): + """Compute intersection of a (start, stop, step) and a (start, stop) tuple. + + Assumes all values are positive. Parameters ---------- - t1, t2 : 2-tuples + t0: 2-tuple or 3-tuple + Tuple of (start, stop, [step]) representing an index range + t1: 2-tuple + Tuple of (start, stop) representing an index range Returns ------- - 2-tuple or None + 3-tuple or None + A tightly bounded interval. """ - stop = min(t1[1], t2[1]) - start = max(t1[0], t2[0]) - return (start, stop) if stop - start > 0 else None + if len(t0) == 2 or t0[2] is None: + # default step is 1 + t0 = (t0[0], t0[1], 1) + + start0, stop0, step0 = t0 + start1, stop1 = t1 + if start0 < start1: + n = int(numpy.ceil((start1 - start0) / step0)) + start2 = start0 + n*step0 + else: + start2 = start0 + + max_stop = min(t0[1], t1[1]) + if (max_stop - start2) % step0 == 0: + n = ((max_stop - start2) // step0) - 1 + else: + n = (max_stop - start2) // step0 + stop2 = (start2 + n*step0) + 1 + return (start2, stop2, step0) if stop2 > start2 else None def positivify(index, size): @@ -384,7 +408,6 @@ def replace_ellipsis(idx): return idx sanitized = tuple(replace_ellipsis(i) for i in sanitized) - if ndim is not None: diff = ndim - len(sanitized) if diff < 0: diff --git a/distarray/tests/test_metadata_utils.py b/distarray/tests/test_metadata_utils.py index b57113a9..4a898101 100644 --- a/distarray/tests/test_metadata_utils.py +++ b/distarray/tests/test_metadata_utils.py @@ -139,6 +139,81 @@ def test_multiple_ellipsis(self): ndim=ndim) self.assertEqual(sanitized, (slice(None),) * 4 + (10, slice(None))) + def test_step(self): + # currently doesn't touch step + indices = (slice(None, None, 2), slice(None, 8, 4)) + tag, sanitized = metadata_utils.sanitize_indices(indices) + self.assertEqual(tag, 'view') + self.assertEqual(sanitized, indices) + + +class TestTupleIntersection(unittest.TestCase): + + def check_intersection_and_reverse(self, t0, t1, expected): + result = metadata_utils.tuple_intersection(t0, t1) + self.assertEqual(result, expected) + result = metadata_utils.tuple_intersection(t1, t0) + self.assertEqual(result, expected) + + def test_no_step_full_enclosure(self): + t0 = (0, 60) + t1 = (15, 30) + expected = (15, 30, 1) + self.check_intersection_and_reverse(t0, t1, expected) + + def test_no_step_partial_overlap(self): + t0 = (0, 60) + t1 = (15, 90) + expected = (15, 60, 1) + self.check_intersection_and_reverse(t0, t1, expected) + + def test_no_step_no_overlap(self): + t0 = (0, 60) + t1 = (80, 130) + expected = None + self.check_intersection_and_reverse(t0, t1, expected) + + def test_no_step_partial_overlap_0(self): + t0 = (0, 60) + t1 = (15, 90) + expected = (15, 60, 1) + self.check_intersection_and_reverse(t0, t1, expected) + + def test_no_step_partial_overlap_1(self): + # regression test + t0 = (0, 4) + t1 = (3, 7) + expected = (3, 4, 1) + self.check_intersection_and_reverse(t0, t1, expected) + + def test_with_step_1(self): + t0 = (0, 60, 1) + t1 = (15, 30) + expected = (15, 30, 1) + result = metadata_utils.tuple_intersection(t0, t1) + self.assertSequenceEqual(result, expected) + + def test_with_step_2(self): + t0 = (0, 60, 2) + t1 = (15, 30) + expected = (16, 29, 2) + result = metadata_utils.tuple_intersection(t0, t1) + self.assertSequenceEqual(result, expected) + + def test_with_step_3(self): + t0 = (0, 59, 2) + t1 = (15, 90) + expected = (16, 59, 2) + result = metadata_utils.tuple_intersection(t0, t1) + self.assertSequenceEqual(result, expected) + + def test_big_step(self): + t0 = (0, 59, 1000) + t1 = (15, 90) + expected = None + result = metadata_utils.tuple_intersection(t0, t1) + self.assertEqual(result, expected) + class TestGridSizes(unittest.TestCase):