Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions distarray/dist/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,16 @@ def __repr__(self):
(self.shape, self.targets)
return s

def _process_return_value(self, result, return_proxy, index, targets):
def _process_return_value(self, result, return_proxy, index, targets,
new_distribution=None):

if return_proxy:
# proxy returned as result of slice
# slicing shouldn't alter the dtype
result = result[0]
return DistArray.from_localarrays(key=result,
context=self.context,
targets=targets,
distribution=new_distribution,
dtype=self.dtype)

elif isinstance(result, Sequence):
Expand Down Expand Up @@ -182,6 +183,7 @@ def get_slice(arr, index, ddpr, comm):
targets = self.distribution.owning_targets(index)

args = [self.key, index]
new_distribution = None
if self.distribution.has_precise_index:
if return_proxy: # returning a new DistArray view
new_distribution = self.distribution.slice(index)
Expand All @@ -194,7 +196,8 @@ def get_slice(arr, index, ddpr, comm):
local_fn = checked_getitem

result = self.context.apply(local_fn, args=args, targets=targets)
return self._process_return_value(result, return_proxy, index, targets)
return self._process_return_value(result, return_proxy, index, targets,
new_distribution=new_distribution)

def __setitem__(self, index, value):
# to be run locally
Expand Down
38 changes: 29 additions & 9 deletions distarray/dist/tests/test_distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,17 @@ def test_full_slice_block_dist(self):
size = 10
expected = numpy.random.randint(11, size=size)
arr = self.context.fromarray(expected)
assert_array_equal(arr[:].toarray(), expected)
out = arr[:]
self.assertSequenceEqual(out.dist, ('b',))
assert_array_equal(out.toarray(), expected)

def test_partial_slice_block_dist(self):
size = 10
expected = numpy.random.randint(10, size=size)
arr = self.context.fromarray(expected)
assert_array_equal(arr[0:2].toarray(), expected[0:2])
out = arr[0:2]
self.assertSequenceEqual(out.dist, ('b',))
assert_array_equal(out.toarray(), expected[0:2])

def test_slice_a_slice_block_dist_0(self):
size = 10
Expand All @@ -135,6 +139,7 @@ def test_slice_a_slice_block_dist_0(self):
s0 = arr[:9]
s1 = s0[0:5]
s2 = s1[:2]
self.assertSequenceEqual(s2.dist, ('b',))
assert_array_equal(s2.toarray(), expected[:2])

def test_slice_a_slice_block_dist_1(self):
Expand All @@ -144,39 +149,50 @@ def test_slice_a_slice_block_dist_1(self):
s0 = arr[:9]
s1 = s0[0:5]
s2 = s1[-2:]
self.assertSequenceEqual(s2.dist, ('b',))
assert_array_equal(s2.toarray(), expected[3:5])

def test_slice_block_dist_1d_with_step(self):
size = 10
step = 2
expected = numpy.random.randint(10, size=size)
darr = self.context.fromarray(expected)
assert_array_equal(darr[::step].toarray(), expected[::step])
out = darr[::step]
self.assertSequenceEqual(out.dist, ('b',))
assert_array_equal(out.toarray(), expected[::step])

def test_partial_slice_block_dist_2d(self):
shape = (10, 20)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[2:6, 3:10].toarray(), expected[2:6, 3:10])
out = arr[2:6, 3:10]
self.assertSequenceEqual(out.dist, ('b', 'n'))
assert_array_equal(out.toarray(), expected[2:6, 3:10])

def test_partial_negative_slice_block_dist_2d(self):
shape = (10, 20)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[-6:-2, -10:-3].toarray(),
out = arr[-6:-2, -10:-3]
self.assertSequenceEqual(out.dist, ('b', 'n'))
assert_array_equal(out.toarray(),
expected[-6:-2, -10:-3])

def test_incomplete_slice_block_dist_2d(self):
shape = (10, 20)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[3:9].toarray(), expected[3:9])
out = arr[3:9]
self.assertSequenceEqual(out.dist, ('b', 'n'))
assert_array_equal(out.toarray(), expected[3:9])

def test_incomplete_index_block_dist_2d(self):
shape = (10, 20)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[1].toarray(), expected[1])
out = arr[1]
self.assertSequenceEqual(out.dist, ('n'))
assert_array_equal(out.toarray(), expected[1])

@unittest.expectedFailure
def test_empty_slice_1d(self):
Expand All @@ -196,13 +212,17 @@ def test_trailing_ellipsis(self):
shape = (2, 3, 7, 6)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[1, ...].toarray(), expected[1, ...])
out = arr[1, ...]
self.assertSequenceEqual(out.dist, ('n', 'n', 'n'))
assert_array_equal(out.toarray(), expected[1, ...])

def test_leading_ellipsis(self):
shape = (2, 3, 7, 6)
expected = numpy.random.randint(10, size=shape)
arr = self.context.fromarray(expected)
assert_array_equal(arr[..., 3].toarray(), expected[..., 3])
out = arr[..., 3]
self.assertSequenceEqual(out.dist, ('b', 'n', 'n'))
assert_array_equal(out.toarray(), expected[..., 3])

def test_multiple_ellipsis(self):
shape = (2, 4, 2, 4, 1, 5)
Expand Down