diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 3c371b6e..0c99fd83 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -179,12 +179,13 @@ def get_slice(arr, index, ddpr, comm): return_type, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) return_proxy = (return_type == 'view') - targets = self.distribution.owning_targets(index) + targets = self.distribution.owning_targets(index) or [0] args = [self.key, index] if self.distribution.has_precise_index: if return_proxy: # returning a new DistArray view new_distribution = self.distribution.slice(index) + targets = new_distribution.targets ddpr = new_distribution.get_dim_data_per_rank() args.extend([ddpr, new_distribution.comm]) local_fn = get_slice diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index d4cead4f..f7d0e2b3 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -666,7 +666,7 @@ def has_precise_index(self): def slice(self, index_tuple): """Make a new Distribution from a slice.""" - new_targets = self.owning_targets(index_tuple) + new_targets = self.owning_targets(index_tuple) or [0] global_dim_data = [] # iterate over the dimensions for map_, idx in zip(self.maps, index_tuple): diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 1f84a5f7..3ddf79cb 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -178,19 +178,25 @@ def test_incomplete_index_block_dist_2d(self): arr = self.context.fromarray(expected) assert_array_equal(arr[1].toarray(), expected[1]) - @unittest.expectedFailure def test_empty_slice_1d(self): shape = (10,) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[100:].toarray(), expected[100:]) + out = arr[100:] + self.assertEqual(out.shape, (0,)) + self.assertEqual(out.grid_shape, (1,)) + self.assertEqual(len(out.targets), 1) + assert_array_equal(out.toarray(), expected[100:]) - @unittest.expectedFailure def test_empty_slice_2d(self): shape = (10, 20) expected = numpy.random.randint(10, size=shape) arr = self.context.fromarray(expected) - assert_array_equal(arr[100:, 100:].toarray(), expected[100:, 100:]) + out = arr[100:, 100:] + self.assertEqual(out.shape, (0, 0)) + self.assertEqual(out.grid_shape, (1, 1)) + self.assertEqual(len(out.targets), 1) + assert_array_equal(out.toarray(), expected[100:, 100:]) def test_trailing_ellipsis(self): shape = (2, 3, 7, 6)