From 68e6e31681cc5589397e6dff5d859b66be071834 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 12:28:05 -0500 Subject: [PATCH 01/20] WIP: Add failing test. --- distarray/dist/tests/test_distarray.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index d28a5143..491f2c40 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -210,6 +210,24 @@ def test_vestigial_ellipsis(self): expected[0, :, 0, ...]) +class TestSetItemSlicing(unittest.TestCase): + + def setUp(self): + self.dac = Context() + + def tearDown(self): + self.dac.close() + + def test_1d_slice(self): + source = numpy.random.randint(10, size=20) + new_data = numpy.random.randint(10, size=5) + slc = slice(12, 12+len(new_data)) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): def setUp(self): From ac301b9a18fac3f3f876a5e60ee35936f6b99568 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 12:35:42 -0500 Subject: [PATCH 02/20] Add a comment. --- distarray/dist/distarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 72def05d..774ae850 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -211,7 +211,7 @@ def raw_setitem(arr, index, value): args = (self.key, index, value) if self.distribution.has_precise_index: self.context.apply(raw_setitem, args=args, targets=targets) - else: + else: # setting unstructured elements result = self.context.apply(checked_setitem, args=args, targets=targets) result = [i for i in result if i is not None] From 992fad5993dc692e735cf606402f5f4fe30f8a47 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 12:36:14 -0500 Subject: [PATCH 03/20] Already works for a setitem that doesn't span procs. --- distarray/dist/tests/test_distarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 491f2c40..8e5a2a34 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -220,8 +220,8 @@ def tearDown(self): def test_1d_slice(self): source = numpy.random.randint(10, size=20) - new_data = numpy.random.randint(10, size=5) - slc = slice(12, 12+len(new_data)) + new_data = numpy.random.randint(10, size=3) + slc = slice(1, 4) arr = self.dac.fromarray(source) source[slc] = new_data arr[slc] = new_data From c03b560607eb49696df4546281447034a7f69adf Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 12:37:29 -0500 Subject: [PATCH 04/20] Add a new failing test. --- distarray/dist/tests/test_distarray.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 8e5a2a34..06b9f89b 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -218,7 +218,7 @@ def setUp(self): def tearDown(self): self.dac.close() - def test_1d_slice(self): + def test_small_1d_slice(self): source = numpy.random.randint(10, size=20) new_data = numpy.random.randint(10, size=3) slc = slice(1, 4) @@ -227,6 +227,15 @@ def test_1d_slice(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_large_1d_slice(self): + source = numpy.random.randint(10, size=20) + new_data = numpy.random.randint(10, size=10) + slc = slice(5, 15) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From b38afa354c4bf69e1ae016252744c911a4dfb288 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 13:57:05 -0500 Subject: [PATCH 05/20] Add `__setitem__` slicing. --- distarray/dist/distarray.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 774ae850..11488ef2 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -189,7 +189,6 @@ def get_slice(arr, index, ddpr, comm): return_proxy=return_proxy) return self._process_return_value(result, return_proxy, index, targets) - def __setitem__(self, index, value): #TODO: FIXME: major performance improvements possible here. # Especially when `index == slice(None)` and value is an @@ -205,15 +204,40 @@ def checked_setitem(arr, index, value): def raw_setitem(arr, index, value): arr.global_index[index] = value - _, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) + # to be run locally + def set_slice(arr, index, value, value_slices): + local_slice = value_slices[arr.comm_rank] + arr.global_index[index] = value[local_slice] + + set_type, index = sanitize_indices(index, ndim=self.ndim, + shape=self.shape) targets = self.distribution.owning_targets(index) - args = (self.key, index, value) + args = [self.key, index, value] if self.distribution.has_precise_index: - self.context.apply(raw_setitem, args=args, targets=targets) + if set_type == 'value': + local_fn = raw_setitem + elif set_type == 'view': + # this could be made more efficient + # we only need the bounds computed by distribution.slice + new_distribution = self.distribution.slice(index) + ddpr = new_distribution.get_dim_data_per_rank() + value_slices = [tuple(slice(dd['start'], dd['stop']) + for dd in dim_data) + for dim_data in ddpr] + # but we need a data structure indexable by a target's rank + # assume contigious range of targets here + value_slices_per_target = [None] * len(self.targets) + value_slices_per_target[targets[0]:targets[-1]] = value_slices + args.append(value_slices_per_target) + local_fn = set_slice + else: + assert False + self.context.apply(local_fn, args=args, targets=targets) + else: # setting unstructured elements - result = self.context.apply(checked_setitem, args=args, - targets=targets) + local_fn = checked_setitem + result = self.context.apply(local_fn, args=args, targets=targets) result = [i for i in result if i is not None] if len(result) > 1: raise IndexError("Setting more than one result (%s) is " From d59e31e228f30090d01020fc19bdc4ecf40ddd4c Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:06:09 -0500 Subject: [PATCH 06/20] Make work for 2d slices. --- distarray/dist/distarray.py | 11 +++++++++-- distarray/dist/tests/test_distarray.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 11488ef2..7fd8fa0d 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -222,8 +222,15 @@ def set_slice(arr, index, value, value_slices): # we only need the bounds computed by distribution.slice new_distribution = self.distribution.slice(index) ddpr = new_distribution.get_dim_data_per_rank() - value_slices = [tuple(slice(dd['start'], dd['stop']) - for dd in dim_data) + def bounds_slice(dd): + if dd['dist_type'] == 'b': + return slice(dd['start'], dd['stop']) + elif dd['dist_type'] == 'n': + return slice(0, dd['size']) + else: + msg = "Function only works for 'n' and 'b' 'dist_type's" + raise TypeError(msg) + value_slices = [tuple(bounds_slice(dd) for dd in dim_data) for dim_data in ddpr] # but we need a data structure indexable by a target's rank # assume contigious range of targets here diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 06b9f89b..8df25a16 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -236,6 +236,15 @@ def test_large_1d_slice(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_2d_slice(self): + source = numpy.random.randint(10, size=(10, 20)) + new_data = numpy.random.randint(10, size=(5, 10)) + slc = (slice(5, 10), slice(5, 15)) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From 23afae0da176c98c4974bfe971611c5f3ba58805 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:23:39 -0500 Subject: [PATCH 07/20] Add more setitem slice tests. --- distarray/dist/tests/test_distarray.py | 31 +++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 8df25a16..d9aae6bf 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -236,7 +236,8 @@ def test_large_1d_slice(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) - def test_2d_slice(self): + def test_2d_slice_0(self): + # on process boundaries source = numpy.random.randint(10, size=(10, 20)) new_data = numpy.random.randint(10, size=(5, 10)) slc = (slice(5, 10), slice(5, 15)) @@ -245,6 +246,34 @@ def test_2d_slice(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_2d_slice_1(self): + # off process boundaries + source = numpy.random.randint(10, size=(10, 20)) + new_data = numpy.random.randint(10, size=(5, 10)) + slc = (slice(3, 8), slice(9, 19)) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_full_3d_slice(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 4, 5)) + slc = (slice(None), slice(None), slice(None)) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_full_3d_slice_ellipsis(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 4, 5)) + slc = Ellipsis + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From ec659addc02c1337f7b09bb6b46ee40a8f72f0f8 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:26:50 -0500 Subject: [PATCH 08/20] Add more tests. --- distarray/dist/tests/test_distarray.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index d9aae6bf..2106b935 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -274,6 +274,24 @@ def test_full_3d_slice_ellipsis(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_partial_indexing_0(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(4, 5)) + slc = (1,) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + + def test_partial_indexing_1(self): + source = numpy.random.randint(10, size=(3, 4, 5)) + new_data = numpy.random.randint(10, size=(3, 5)) + slc = (slice(None), 2) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From 5ddd354d1c4fcee0785940f3fb0631161589ea52 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:28:09 -0500 Subject: [PATCH 09/20] Make setUp and tearDown classmethods. --- distarray/dist/tests/test_distarray.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 2106b935..3da35b9d 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -212,11 +212,13 @@ def test_vestigial_ellipsis(self): class TestSetItemSlicing(unittest.TestCase): - def setUp(self): - self.dac = Context() + @classmethod + def setUpClass(cls): + cls.dac = Context() - def tearDown(self): - self.dac.close() + @classmethod + def tearDownClass(cls): + cls.dac.close() def test_small_1d_slice(self): source = numpy.random.randint(10, size=20) From 858ac2159ed595919220ba26317484a04fd931a5 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:30:30 -0500 Subject: [PATCH 10/20] Remove completed TODO comment. --- distarray/dist/distarray.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 7fd8fa0d..35af82e9 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -190,12 +190,6 @@ def get_slice(arr, index, ddpr, comm): return self._process_return_value(result, return_proxy, index, targets) def __setitem__(self, index, value): - #TODO: FIXME: major performance improvements possible here. - # Especially when `index == slice(None)` and value is an - # ndarray, since for block and cyclic, we can generate slices of - # `value` and assign to local arrays. This would dramatically - # improve the fromndarray method's performance. - # to be run locally def checked_setitem(arr, index, value): return arr.global_index.checked_setitem(index, value) From 342f19fc16a0db4c0aa14626da38e2e64be2f476 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 14:42:01 -0500 Subject: [PATCH 11/20] Convert a non-array rvalue to array. --- distarray/dist/distarray.py | 2 +- distarray/dist/tests/test_distarray.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 35af82e9..9dcc1f79 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -207,7 +207,7 @@ def set_slice(arr, index, value, value_slices): shape=self.shape) targets = self.distribution.owning_targets(index) - args = [self.key, index, value] + args = [self.key, index, np.asarray(value)] if self.distribution.has_precise_index: if set_type == 'value': local_fn = raw_setitem diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 3da35b9d..cc756cfc 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -294,6 +294,15 @@ def test_partial_indexing_1(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_non_array_data(self): + source = numpy.random.randint(10, size=(3, 4)) + new_data = [42, 42, 42, 42] + slc = (2,) + arr = self.dac.fromarray(source) + source[slc] = new_data + arr[slc] = new_data + assert_array_equal(arr.toarray(), source) + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From 6c7e84af3d0cbac5a3b9a793f72f50f970d73d0f Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 15:57:34 -0500 Subject: [PATCH 12/20] Add failing ValueError test. --- distarray/dist/tests/test_distarray.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index cc756cfc..0cc7c261 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -303,6 +303,14 @@ def test_non_array_data(self): arr[slc] = new_data assert_array_equal(arr.toarray(), source) + def test_valueerror(self): + source = numpy.random.randint(10, size=21) + new_data = numpy.random.randint(10, size=10) + slc = slice(15, None) + arr = self.dac.fromarray(source) + with self.assertRaises(ValueError): + arr[slc] = new_data + class TestDistArrayCreationFromGlobalDimData(unittest.TestCase): From df4f6997c747a5be8c355f027e1b52a26019d562 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 16:11:24 -0500 Subject: [PATCH 13/20] Remove an obsolete comment. --- distarray/metadata_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index 1c02ffcf..c3921144 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -256,8 +256,6 @@ def positivify(index, size): ------ IndexError for out-of-bounds indices - NotImplementedError - for negative steps """ if isinstance(index, Integral): index = _positivify(index, size) From 29a33a769061066371fe920d32cfc103b1167877 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 16:11:47 -0500 Subject: [PATCH 14/20] Raise an IndexError instead of a TypeError... to match NumPy. --- distarray/metadata_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index c3921144..713ea0e9 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -304,7 +304,7 @@ def sanitize_indices(indices, ndim=None, shape=None): else: msg = ("Index must be an Integral, a slice, or a sequence of " "Integrals and slices.") - raise TypeError(msg) + raise IndexError(msg) if Ellipsis in sanitized: if ndim is None: From c78e8d2d7154cc356170613c197b3b5671fde6ed Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Tue, 27 May 2014 16:39:09 -0500 Subject: [PATCH 15/20] Raise a ValueError if rvalue shape is incorrect... like NumPy does. --- distarray/dist/distarray.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 9dcc1f79..5e0882a0 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -207,14 +207,18 @@ def set_slice(arr, index, value, value_slices): shape=self.shape) targets = self.distribution.owning_targets(index) - args = [self.key, index, np.asarray(value)] + args = [self.key, index, value] if self.distribution.has_precise_index: if set_type == 'value': local_fn = raw_setitem elif set_type == 'view': + args[-1] = np.asarray(args[-1]) # convert to array # this could be made more efficient # we only need the bounds computed by distribution.slice new_distribution = self.distribution.slice(index) + if args[-1].shape != new_distribution.shape: + msg = "Slice shape does not equal rvalue shape." + raise ValueError(msg) ddpr = new_distribution.get_dim_data_per_rank() def bounds_slice(dd): if dd['dist_type'] == 'b': From 2d75aa4c59a654db4920ef87f49c5070594b5760 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Thu, 12 Jun 2014 16:08:59 -0500 Subject: [PATCH 16/20] Add test for @kwmsmith's "strange behavior". Seems to pass for me though... --- distarray/dist/tests/test_distarray.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 2f587e30..6361ba98 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -212,6 +212,13 @@ def test_0d_ellipsis(self): assert_array_equal(arr[...].toarray(), expected[...]) + def test_resulting_slice(self): + dist = Distribution.from_shape(self.context, (10, 20)) + da = self.context.ones(dist) + db = da[:5, :10] + dc = db * 2 + assert_array_equal(dc.toarray(), numpy.ones(dc.shape) * 2) + class TestSetItemSlicing(ContextTestCase): From 581ae6f14e9df7c8301c3c364af9646dcf5fd3e2 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Thu, 12 Jun 2014 16:28:41 -0500 Subject: [PATCH 17/20] Fix typo in comment. --- distarray/dist/distarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index a50ff613..2db13a40 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -235,7 +235,7 @@ def bounds_slice(dd): value_slices = [tuple(bounds_slice(dd) for dd in dim_data) for dim_data in ddpr] # but we need a data structure indexable by a target's rank - # assume contigious range of targets here + # assume contiguous range of targets here value_slices_per_target = [None] * len(self.targets) value_slices_per_target[targets[0]:targets[-1]] = value_slices args.append(value_slices_per_target) From fef6142564381491742dae13ec0e4c9863916b9c Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Thu, 12 Jun 2014 16:29:02 -0500 Subject: [PATCH 18/20] Add failing test from @kwmsmith. --- distarray/dist/tests/test_distarray.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index 6361ba98..b224c0ca 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -313,6 +313,15 @@ def test_valueerror(self): with self.assertRaises(ValueError): arr[slc] = new_data + def test_resulting_slice(self): + dist = Distribution.from_shape(self.context, (10, 20)) + da = self.context.ones(dist) + db = da[:5, :10] + arr = db.tondarray() + db[...] = arr * 2 + assert_array_equal(db.toarray(), numpy.ones(db.shape) * 2) + db[...] = db * 2 + class TestDistArrayCreationFromGlobalDimData(ContextTestCase): From 085cc31af0655f92124f9fd7acca11294aa51ed5 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Thu, 12 Jun 2014 17:49:57 -0500 Subject: [PATCH 19/20] Improve failing test. --- distarray/dist/tests/test_distarray.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/distarray/dist/tests/test_distarray.py b/distarray/dist/tests/test_distarray.py index b224c0ca..1c6d2940 100644 --- a/distarray/dist/tests/test_distarray.py +++ b/distarray/dist/tests/test_distarray.py @@ -313,14 +313,11 @@ def test_valueerror(self): with self.assertRaises(ValueError): arr[slc] = new_data - def test_resulting_slice(self): + def test_set_DistArray_slice(self): dist = Distribution.from_shape(self.context, (10, 20)) da = self.context.ones(dist) - db = da[:5, :10] - arr = db.tondarray() - db[...] = arr * 2 - assert_array_equal(db.toarray(), numpy.ones(db.shape) * 2) - db[...] = db * 2 + db = self.context.zeros(dist) + da[...] = db class TestDistArrayCreationFromGlobalDimData(ContextTestCase): From bd01d1f2a6719007021a68814bb82b25c685c885 Mon Sep 17 00:00:00 2001 From: Robert David Grant Date: Thu, 12 Jun 2014 17:50:27 -0500 Subject: [PATCH 20/20] Fix DistArrays as rvalues in slicing setitem. --- distarray/dist/distarray.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 2db13a40..9fb44dbd 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -204,8 +204,12 @@ def raw_setitem(arr, index, value): # to be run locally def set_slice(arr, index, value, value_slices): - local_slice = value_slices[arr.comm_rank] - arr.global_index[index] = value[local_slice] + from distarray.local.localarray import LocalArray + slice_ = value_slices[arr.comm_rank] + if isinstance(value, LocalArray): + arr.global_index[index] = value.ndarray + else: + arr.global_index[index] = value[slice_] set_type, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) @@ -216,13 +220,20 @@ def set_slice(arr, index, value, value_slices): if set_type == 'value': local_fn = raw_setitem elif set_type == 'view': - args[-1] = np.asarray(args[-1]) # convert to array + new_distribution = self.distribution.slice(index) # this could be made more efficient # we only need the bounds computed by distribution.slice - new_distribution = self.distribution.slice(index) - if args[-1].shape != new_distribution.shape: - msg = "Slice shape does not equal rvalue shape." - raise ValueError(msg) + if isinstance(args[-1], DistArray): + if not args[-1].distribution.is_compatible( + new_distribution): + msg = "rvalue Distribution not compatible." + raise ValueError(msg) + args[-1] = args[-1].key + else: + args[-1] = np.asarray(args[-1]) # convert to array + if args[-1].shape != new_distribution.shape: + msg = "Slice shape does not equal rvalue shape." + raise ValueError(msg) ddpr = new_distribution.get_dim_data_per_rank() def bounds_slice(dd): if dd['dist_type'] == 'b':