diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 0e806e28..9fb44dbd 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -193,14 +193,7 @@ def get_slice(arr, index, ddpr, comm): result = self.context.apply(local_fn, args=args, targets=targets) return self._process_return_value(result, return_proxy, index, targets) - def __setitem__(self, index, value): - #TODO: FIXME: major performance improvements possible here. - # Especially when `index == slice(None)` and value is an - # ndarray, since for block and cyclic, we can generate slices of - # `value` and assign to local arrays. This would dramatically - # improve the fromndarray method's performance. - # to be run locally def checked_setitem(arr, index, value): return arr.global_index.checked_setitem(index, value) @@ -209,15 +202,62 @@ def checked_setitem(arr, index, value): def raw_setitem(arr, index, value): arr.global_index[index] = value - _, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) + # to be run locally + def set_slice(arr, index, value, value_slices): + from distarray.local.localarray import LocalArray + slice_ = value_slices[arr.comm_rank] + if isinstance(value, LocalArray): + arr.global_index[index] = value.ndarray + else: + arr.global_index[index] = value[slice_] + + set_type, index = sanitize_indices(index, ndim=self.ndim, + shape=self.shape) targets = self.distribution.owning_targets(index) - args = (self.key, index, value) + args = [self.key, index, value] if self.distribution.has_precise_index: - self.context.apply(raw_setitem, args=args, targets=targets) - else: - result = self.context.apply(checked_setitem, args=args, - targets=targets) + if set_type == 'value': + local_fn = raw_setitem + elif set_type == 'view': + new_distribution = self.distribution.slice(index) + # this could be made more efficient + # we only need the bounds computed by distribution.slice + if isinstance(args[-1], DistArray): + if not args[-1].distribution.is_compatible( + new_distribution): + msg = "rvalue Distribution not compatible." + raise ValueError(msg) + args[-1] = args[-1].key + else: + args[-1] = np.asarray(args[-1]) # convert to array + if args[-1].shape != new_distribution.shape: + msg = "Slice shape does not equal rvalue shape." + raise ValueError(msg) + ddpr = new_distribution.get_dim_data_per_rank() + def bounds_slice(dd): + if dd['dist_type'] == 'b': + return slice(dd['start'], dd['stop']) + elif dd['dist_type'] == 'n': + return slice(0, dd['size']) + else: + msg = "Function only works for 'n' and 'b' 'dist_type's" + raise TypeError(msg) + value_slices = [tuple(bounds_slice(dd) for dd in dim_data) + for dim_data in ddpr] + # but we need a data structure indexable by a target's rank + # assume contiguous range of targets here + value_slices_per_target = [None] * len(self.targets) + value_slices_per_target[targets[0]:targets[-1]] = value_slices + args.append(value_slices_per_target) + local_fn = set_slice + else: + assert False + self.context.apply(local_fn, args=args, targets=targets) + + else: # setting unstructured elements + local_fn = checked_setitem + result = self.context.apply(local_fn, args=args, targets=targets) result = [i for i in result if i is not None] if len(result) > 1: raise IndexError("Setting more than one result (%s) is " diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 18ceac53..1c6d2940 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -212,6 +212,113 @@ def test_0d_ellipsis(self): assert_array_equal(arr[...].toarray(), expected[...]) + def test_resulting_slice(self): + dist = Distribution.from_shape(self.context, (10, 20)) + da = self.context.ones(dist) + db = da[:5, :10] + dc = db * 2 + assert_array_equal(dc.toarray(), numpy.ones(dc.shape) * 2) + + +class TestSetItemSlicing(ContextTestCase): + + def test_small_1d_slice(self): + source = numpy.random.randint(10, size=20) + new_data = numpy.random.randint(10, size=3) + slc = slice(1, 4) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_large_1d_slice(self): + source = numpy.random.randint(10, size=20) + new_data = numpy.random.randint(10, size=10) + slc = slice(5, 15) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_2d_slice_0(self): + # on process boundaries + source = numpy.random.randint(10, size=(10, 20)) + new_data = numpy.random.randint(10, size=(5, 10)) + slc = (slice(5, 10), slice(5, 15)) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_2d_slice_1(self): + # off process boundaries + source = numpy.random.randint(10, size=(10, 20)) + new_data = numpy.random.randint(10, size=(5, 10)) + slc = (slice(3, 8), slice(9, 19)) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_full_3d_slice(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 4, 5)) + slc = (slice(None), slice(None), slice(None)) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_full_3d_slice_ellipsis(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 4, 5)) + slc = Ellipsis + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_partial_indexing_0(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(4, 5)) + slc = (1,) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_partial_indexing_1(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 5)) + slc = (slice(None), 2) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_non_array_data(self): + source = numpy.random.randint(10, size=(3, 4)) + new_data = [42, 42, 42, 42] + slc = (2,) + arr = self.context.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_valueerror(self): + source = numpy.random.randint(10, size=21) + new_data = numpy.random.randint(10, size=10) + slc = slice(15, None) + arr = self.context.fromarray(source) + with self.assertRaises(ValueError): + arr[slc] = new_data + + def test_set_DistArray_slice(self): + dist = Distribution.from_shape(self.context, (10, 20)) + da = self.context.ones(dist) + db = self.context.zeros(dist) + da[...] = db + class TestDistArrayCreationFromGlobalDimData(ContextTestCase): diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index e01f7636..945e5310 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -316,8 +316,6 @@ def positivify(index, size): ------ IndexError for out-of-bounds indices - NotImplementedError - for negative steps """ if isinstance(index, Integral): index = _positivify(index, size) @@ -366,7 +364,7 @@ def sanitize_indices(indices, ndim=None, shape=None): else: msg = ("Index must be an Integral, a slice, or a sequence of " "Integrals and slices.") - raise TypeError(msg) + raise IndexError(msg) if Ellipsis in sanitized: if ndim is None: