diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 3c371b6e..07247fd2 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -131,15 +131,16 @@ def __repr__(self): (self.shape, self.targets) return s - def _process_return_value(self, result, return_proxy, index, targets): + def _process_return_value(self, result, return_proxy, index, targets, + new_distribution=None): if return_proxy: # proxy returned as result of slice # slicing shouldn't alter the dtype result = result[0] return DistArray.from_localarrays(key=result, - context=self.context, targets=targets, + distribution=new_distribution, dtype=self.dtype) elif isinstance(result, Sequence): @@ -182,6 +183,7 @@ def get_slice(arr, index, ddpr, comm): targets = self.distribution.owning_targets(index) args = [self.key, index] + new_distribution = None if self.distribution.has_precise_index: if return_proxy: # returning a new DistArray view new_distribution = self.distribution.slice(index) @@ -194,7 +196,8 @@ def get_slice(arr, index, ddpr, comm): local_fn = checked_getitem result = self.context.apply(local_fn, args=args, targets=targets) - return self._process_return_value(result, return_proxy, index, targets) + return self._process_return_value(result, return_proxy, index, targets, + new_distribution=new_distribution) def __setitem__(self, index, value): # to be run locally diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 1f84a5f7..6b913e60 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -120,13 +120,17 @@ def test_full_slice_block_dist(self): size = 10 expected = numpy.random.randint(11, size=size) arr = self.context.fromarray(expected) - assert_array_equal(arr[:].toarray(), expected) + out = arr[:] + self.assertSequenceEqual(out.dist, ('b',)) + assert_array_equal(out.toarray(), expected) def test_partial_slice_block_dist(self): size = 10 expected = numpy.random.randint(10, size=size) arr = self.context.fromarray(expected) - assert_array_equal(arr[0:2].toarray(), expected[0:2]) + out = arr[0:2] + self.assertSequenceEqual(out.dist, ('b',)) + assert_array_equal(out.toarray(), expected[0:2]) def test_slice_a_slice_block_dist_0(self): size = 10 @@ -135,6 +139,7 @@ def test_slice_a_slice_block_dist_0(self): s0 = arr[:9] s1 = s0[0:5] s2 = s1[:2] + self.assertSequenceEqual(s2.dist, ('b',)) assert_array_equal(s2.toarray(), expected[:2]) def test_slice_a_slice_block_dist_1(self): @@ -144,6 +149,7 @@ def test_slice_a_slice_block_dist_1(self): s0 = arr[:9] s1 = s0[0:5] s2 = s1[-2:] + self.assertSequenceEqual(s2.dist, ('b',)) assert_array_equal(s2.toarray(), expected[3:5]) def test_slice_block_dist_1d_with_step(self): @@ -151,32 +157,42 @@ def test_slice_block_dist_1d_with_step(self): step = 2 expected = numpy.random.randint(10, size=size) darr = self.context.fromarray(expected) - assert_array_equal(darr[::step].toarray(), expected[::step]) + out = darr[::step] + self.assertSequenceEqual(out.dist, ('b',)) + assert_array_equal(out.toarray(), expected[::step]) def test_partial_slice_block_dist_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[2:6, 3:10].toarray(), expected[2:6, 3:10]) + out = arr[2:6, 3:10] + self.assertSequenceEqual(out.dist, ('b', 'n')) + assert_array_equal(out.toarray(), expected[2:6, 3:10]) def test_partial_negative_slice_block_dist_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[-6:-2, -10:-3].toarray(), + out = arr[-6:-2, -10:-3] + self.assertSequenceEqual(out.dist, ('b', 'n')) + assert_array_equal(out.toarray(), expected[-6:-2, -10:-3]) def test_incomplete_slice_block_dist_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[3:9].toarray(), expected[3:9]) + out = arr[3:9] + self.assertSequenceEqual(out.dist, ('b', 'n')) + assert_array_equal(out.toarray(), expected[3:9]) def test_incomplete_index_block_dist_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[1].toarray(), expected[1]) + out = arr[1] + self.assertSequenceEqual(out.dist, ('n')) + assert_array_equal(out.toarray(), expected[1]) @unittest.expectedFailure def test_empty_slice_1d(self): @@ -196,13 +212,17 @@ def test_trailing_ellipsis(self): shape = (2, 3, 7, 6) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[1, ...].toarray(), expected[1, ...]) + out = arr[1, ...] + self.assertSequenceEqual(out.dist, ('n', 'n', 'n')) + assert_array_equal(out.toarray(), expected[1, ...]) def test_leading_ellipsis(self): shape = (2, 3, 7, 6) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[..., 3].toarray(), expected[..., 3]) + out = arr[..., 3] + self.assertSequenceEqual(out.dist, ('b', 'n', 'n')) + assert_array_equal(out.toarray(), expected[..., 3]) def test_multiple_ellipsis(self): shape = (2, 4, 2, 4, 1, 5)