diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index f7d72086..0cdeb1ee 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -131,43 +131,10 @@ def __repr__(self): (self.shape, self.targets) return s - def _process_return_value(self, result, return_proxy, index, targets, - new_distribution=None): - - if return_proxy: - # proxy returned as result of slice - # slicing shouldn't alter the dtype - result = result[0] - return DistArray.from_localarrays(key=result, - targets=targets, - distribution=new_distribution, - dtype=self.dtype) - - elif isinstance(result, Sequence): - somethings = [i for i in result if i is not None] - if len(somethings) == 0: - # using checked_getitem and all return None - raise IndexError("Index %r is is not present." % (index,)) - if len(somethings) == 1: - return somethings[0] - else: - return result - else: - assert False # impossible is nothing - - - def __getitem__(self, index): - - # to be run locally - def checked_getitem(arr, index): - return arr.global_index.checked_getitem(index) - - # to be run locally - def raw_getitem(arr, index): - return arr.global_index[index] + def _get_view(self, index): # to be run locally - def get_slice(arr, index, ddpr, comm): + def get_view(arr, index, ddpr, comm): from distarray.local.maps import Distribution if len(ddpr) == 0: dim_data = () @@ -177,28 +144,62 @@ def get_slice(arr, index, ddpr, comm): result = arr.global_index.get_slice(index, local_distribution) return proxyize(result) + new_distribution = self.distribution.slice(index) + ddpr = new_distribution.get_dim_data_per_rank() + + args = [self.key, index, ddpr, new_distribution.comm] + targets = new_distribution.targets + result = self.context.apply(get_view, args=args, targets=targets)[0] + + return DistArray.from_localarrays(key=result, + targets=targets, + distribution=new_distribution, + dtype=self.dtype) + + def _get_value(self, index): + + # to be run locally + def get_value(arr, index): + return arr.global_index[index] + + args = [self.key, index] + targets = self.distribution.owning_targets(index) + result = self.context.apply(get_value, args=args, targets=targets) + + return [i for i in result if i is not None][0] + + def _checked_getitem(self, index): + + # to be run locally + def checked_getitem(arr, index): + return arr.global_index.checked_getitem(index) + + args = [self.key, index] + targets = self.distribution.owning_targets(index) + result = self.context.apply(checked_getitem, args=args, targets=targets) + + somethings = [i for i in result if i is not None] + if len(somethings) == 0: + # all return None + raise IndexError("Index %r is is not present." % (index,)) + elif len(somethings) == 1: + return somethings[0] + else: + return result + + def __getitem__(self, index): return_type, index = sanitize_indices(index, ndim=self.ndim, shape=self.shape) - return_proxy = (return_type == 'view') - targets = self.distribution.owning_targets(index) or [0] + if not self.distribution.has_precise_index: + result = self._checked_getitem(index) + elif return_type == 'view': + result = self._get_view(index) + elif return_type == 'value': + result = self._get_value(index) + else: + assert False - args = [self.key, index] - new_distribution = None - if self.distribution.has_precise_index: - if return_proxy: # returning a new DistArray view - new_distribution = self.distribution.slice(index) - targets = new_distribution.targets - ddpr = new_distribution.get_dim_data_per_rank() - args.extend([ddpr, new_distribution.comm]) - local_fn = get_slice - else: # returning a value - local_fn = raw_getitem - else: # returning a value from unstructured - local_fn = checked_getitem - - result = self.context.apply(local_fn, args=args, targets=targets) - return self._process_return_value(result, return_proxy, index, targets, - new_distribution=new_distribution) + return result def __setitem__(self, index, value): # to be run locally