Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions distarray/dist/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
`UnstructuredMap`.

"""
from __future__ import absolute_import
from __future__ import division, absolute_import

import operator
from itertools import product
Expand Down Expand Up @@ -200,7 +200,13 @@ def index_owners(self, idx):
return [0] if 0 <= idx < self.size else []

def slice_owners(self, idx):
return [0] # slicing doesn't complain about out-of-bounds indices
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else self.size
step = idx.step if idx.step is not None else 1
if tuple_intersection((start, stop, step), (0, self.size)):
return [0]
else:
return []

def get_dimdicts(self):
return ({
Expand All @@ -214,14 +220,16 @@ def slice(self, idx):
"""Make a new Map from a slice."""
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else self.size
intersection = tuple_intersection((0, self.size), (start, stop))
if intersection:
intersection_size = intersection[1] - intersection[0]
step = idx.step if idx.step is not None else 1
isection = tuple_intersection((start, stop, step), (0, self.size))
if isection:
step = idx.step if idx.step is not None else 1
isection_size = int(np.ceil((isection[1] - isection[0]) / step))
else:
intersection_size = 0
isection_size = 0

return {'dist_type': self.dist,
'size': intersection_size}
'size': isection_size}


class BlockMap(MapBase):
Expand Down Expand Up @@ -281,15 +289,13 @@ def index_owners(self, idx):

def slice_owners(self, idx):
coords = []
if idx.step not in {None, 1}:
msg = "Slicing only implemented for step=1"
raise NotImplementedError(msg)
start = idx.start if idx.start is not None else 0
stop = idx.stop if idx.stop is not None else self.size
step = idx.step if idx.step is not None else 1
for (coord, (lower, upper)) in enumerate(self.bounds):
slice_tuple = (idx.start if idx.start is not None else 0,
idx.stop if idx.stop is not None else self.size)
if tuple_intersection((lower, upper), slice_tuple):
if tuple_intersection((start, stop, step), (lower, upper)):
coords.append(coord)
return coords if coords != [] else [0]
return coords

def get_dimdicts(self):
grid_ranks = range(len(self.bounds))
Expand All @@ -316,14 +322,17 @@ def slice(self, idx):
"""Make a new Map from a slice."""
new_bounds = [0]
start = idx.start if idx.start is not None else 0
step = idx.step if idx.step is not None else 1
# iterate over the processes in this dimension
for proc_start, proc_stop in self.bounds:
stop = idx.stop if idx.stop is not None else proc_stop
intersection = tuple_intersection((proc_start, proc_stop),
(start, stop))
if intersection:
size = intersection[1] - intersection[0]
new_bounds.append(size + new_bounds[-1])
isection = tuple_intersection((start, stop, step),
(proc_start, proc_stop))
if isection:
isection_size = int(np.ceil((isection[1] - (isection[0])) / step))
new_bounds.append(isection_size + new_bounds[-1])
if len(new_bounds) == [0]:
new_bounds = []

return {'dist_type': self.dist,
'bounds': new_bounds}
Expand Down
34 changes: 34 additions & 0 deletions distarray/dist/tests/test_distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ def test_slice_a_slice_block_dist_1(self):
s2 = s1[-2:]
assert_array_equal(s2.toarray(), expected[3:5])

def test_slice_block_dist_1d_with_step(self):
size = 10
step = 2
expected = numpy.random.randint(10, size=size)
darr = self.context.fromarray(expected)
assert_array_equal(darr[::step].toarray(), expected[::step])

def test_partial_slice_block_dist_2d(self):
shape = (10, 20)
expected = numpy.random.randint(10, size=shape)
Expand Down Expand Up @@ -233,6 +240,15 @@ def test_large_1d_slice(self):
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_1d_slice_with_step(self):
source = numpy.random.randint(10, size=20)
new_data = numpy.random.randint(10, size=5)
slc = slice(7, 17, 2)
arr = self.context.fromarray(source)
source[slc] = new_data
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_2d_slice_0(self):
# on process boundaries
source = numpy.random.randint(10, size=(10, 20))
Expand All @@ -243,6 +259,15 @@ def test_2d_slice_0(self):
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_2d_slice_with_step(self):
source = numpy.random.randint(10, size=(10, 20))
new_data = numpy.random.randint(10, size=(2, 5))
slc = (slice(5, 10, 3), slice(5, 15, 2))
arr = self.context.fromarray(source)
source[slc] = new_data
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_2d_slice_1(self):
# off process boundaries
source = numpy.random.randint(10, size=(10, 20))
Expand Down Expand Up @@ -271,6 +296,15 @@ def test_full_3d_slice_ellipsis(self):
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_3d_slice_ellipsis_with_step(self):
source = numpy.random.randint(10, size=(5, 4, 5))
new_data = numpy.random.randint(10, size=(5, 2, 5))
slc = (Ellipsis, slice(None, None, 2), Ellipsis)
arr = self.context.fromarray(source)
source[slc] = new_data
arr[slc] = new_data
assert_array_equal(arr.toarray(), source)

def test_partial_indexing_0(self):
source = numpy.random.randint(10, size=(3, 4, 5))
new_data = numpy.random.randint(10, size=(4, 5))
Expand Down
25 changes: 25 additions & 0 deletions distarray/dist/tests/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------

from __future__ import division

import unittest
from random import randrange

Expand Down Expand Up @@ -143,6 +145,29 @@ def test_from_full_slice_1d(self):
self.assertSequenceEqual(d1.targets, d0.targets)
self.assertSequenceEqual(d1.maps[0].bounds, d0.maps[0].bounds)

def test_from_full_slice_with_step_1d_0(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15,))

s = (slice(None, None, 2),)
d1 = d0.slice(s)

self.assertEqual(len(d0.maps), len(d1.maps))
self.assertSequenceEqual(d1.dist, d0.dist)
self.assertSequenceEqual(d1.targets, d0.targets)
self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0])

def test_from_full_slice_with_step_1d_1(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(30,))
step = 4

s = (slice(4, None, step),)
d1 = d0.slice(s)

self.assertEqual(len(d0.maps), len(d1.maps))
self.assertSequenceEqual(d1.dist, d0.dist)
self.assertSequenceEqual(d1.targets, d0.targets)
self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0])

def test_from_full_slice_2d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20))

Expand Down
12 changes: 10 additions & 2 deletions distarray/local/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,19 @@ def local_from_global_index(self, gidx):
return gidx - self.start

def local_from_global_slice(self, gidx):
# we don't make the effort to compute the exact slice
# `__getitem__` doesn't care about overly-large slices, we just
# have to get the offset from the start correct based on the `step`
start = gidx.start if gidx.start is not None else 0
stop = gidx.stop if gidx.stop is not None else self.global_size
new_start = max(start - self.start, 0) # prevent negative inds
step = gidx.step if gidx.step is not None else 1
new_start = start - self.start
if new_start < 0: # don't allow negative starts
new_start += step * abs(new_start // step)
if new_start < 0:
new_start += step
new_stop = stop - self.start
return slice(new_start, new_stop)
return slice(new_start, new_stop, gidx.step)

def global_from_local_index(self, lidx):
if lidx >= self.local_size:
Expand Down
39 changes: 31 additions & 8 deletions distarray/metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------

from __future__ import division

import operator
from itertools import product
from functools import reduce
Expand Down Expand Up @@ -288,20 +290,42 @@ def _check_bounds(index, size):
raise IndexError("Index %r out of bounds" % index)


def tuple_intersection(t1, t2):
"""Compute intersection of two (start, stop) tuples.
def tuple_intersection(t0, t1):
"""Compute intersection of a (start, stop, step) and a (start, stop) tuple.

Assumes all values are positive.

Parameters
----------
t1, t2 : 2-tuples
t0: 2-tuple or 3-tuple
Tuple of (start, stop, [step]) representing an index range
t1: 2-tuple
Tuple of (start, stop) representing an index range

Returns
-------
2-tuple or None
3-tuple or None
A tightly bounded interval.
"""
stop = min(t1[1], t2[1])
start = max(t1[0], t2[0])
return (start, stop) if stop - start > 0 else None
if len(t0) == 2 or t0[2] is None:
# default step is 1
t0 = (t0[0], t0[1], 1)

start0, stop0, step0 = t0
start1, stop1 = t1
if start0 < start1:
n = int(numpy.ceil((start1 - start0) / step0))
start2 = start0 + n*step0
else:
start2 = start0

max_stop = min(t0[1], t1[1])
if (max_stop - start2) % step0 == 0:
n = ((max_stop - start2) // step0) - 1
else:
n = (max_stop - start2) // step0
stop2 = (start2 + n*step0) + 1
return (start2, stop2, step0) if stop2 > start2 else None


def positivify(index, size):
Expand Down Expand Up @@ -384,7 +408,6 @@ def replace_ellipsis(idx):
return idx
sanitized = tuple(replace_ellipsis(i) for i in sanitized)


if ndim is not None:
diff = ndim - len(sanitized)
if diff < 0:
Expand Down
75 changes: 75 additions & 0 deletions distarray/tests/test_metadata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,81 @@ def test_multiple_ellipsis(self):
ndim=ndim)
self.assertEqual(sanitized, (slice(None),) * 4 + (10, slice(None)))

def test_step(self):
# currently doesn't touch step
indices = (slice(None, None, 2), slice(None, 8, 4))
tag, sanitized = metadata_utils.sanitize_indices(indices)
self.assertEqual(tag, 'view')
self.assertEqual(sanitized, indices)


class TestTupleIntersection(unittest.TestCase):

def check_intersection_and_reverse(self, t0, t1, expected):
result = metadata_utils.tuple_intersection(t0, t1)
self.assertEqual(result, expected)
result = metadata_utils.tuple_intersection(t1, t0)
self.assertEqual(result, expected)

def test_no_step_full_enclosure(self):
t0 = (0, 60)
t1 = (15, 30)
expected = (15, 30, 1)
self.check_intersection_and_reverse(t0, t1, expected)

def test_no_step_partial_overlap(self):
t0 = (0, 60)
t1 = (15, 90)
expected = (15, 60, 1)
self.check_intersection_and_reverse(t0, t1, expected)

def test_no_step_no_overlap(self):
t0 = (0, 60)
t1 = (80, 130)
expected = None
self.check_intersection_and_reverse(t0, t1, expected)

def test_no_step_partial_overlap_0(self):
t0 = (0, 60)
t1 = (15, 90)
expected = (15, 60, 1)
self.check_intersection_and_reverse(t0, t1, expected)

def test_no_step_partial_overlap_1(self):
# regression test
t0 = (0, 4)
t1 = (3, 7)
expected = (3, 4, 1)
self.check_intersection_and_reverse(t0, t1, expected)

def test_with_step_1(self):
t0 = (0, 60, 1)
t1 = (15, 30)
expected = (15, 30, 1)
result = metadata_utils.tuple_intersection(t0, t1)
self.assertSequenceEqual(result, expected)

def test_with_step_2(self):
t0 = (0, 60, 2)
t1 = (15, 30)
expected = (16, 29, 2)
result = metadata_utils.tuple_intersection(t0, t1)
self.assertSequenceEqual(result, expected)

def test_with_step_3(self):
t0 = (0, 59, 2)
t1 = (15, 90)
expected = (16, 59, 2)
result = metadata_utils.tuple_intersection(t0, t1)
self.assertSequenceEqual(result, expected)

def test_big_step(self):
t0 = (0, 59, 1000)
t1 = (15, 90)
expected = None
result = metadata_utils.tuple_intersection(t0, t1)
self.assertEqual(result, expected)


class TestGridSizes(unittest.TestCase):

Expand Down