Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 55 additions & 54 deletions distarray/dist/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,43 +131,10 @@ def __repr__(self):
(self.shape, self.targets)
return s

def _process_return_value(self, result, return_proxy, index, targets,
new_distribution=None):

if return_proxy:
# proxy returned as result of slice
# slicing shouldn't alter the dtype
result = result[0]
return DistArray.from_localarrays(key=result,
targets=targets,
distribution=new_distribution,
dtype=self.dtype)

elif isinstance(result, Sequence):
somethings = [i for i in result if i is not None]
if len(somethings) == 0:
# using checked_getitem and all return None
raise IndexError("Index %r is is not present." % (index,))
if len(somethings) == 1:
return somethings[0]
else:
return result
else:
assert False # impossible is nothing


def __getitem__(self, index):

# to be run locally
def checked_getitem(arr, index):
return arr.global_index.checked_getitem(index)

# to be run locally
def raw_getitem(arr, index):
return arr.global_index[index]
def _get_view(self, index):

# to be run locally
def get_slice(arr, index, ddpr, comm):
def get_view(arr, index, ddpr, comm):
from distarray.local.maps import Distribution
if len(ddpr) == 0:
dim_data = ()
Expand All @@ -177,28 +144,62 @@ def get_slice(arr, index, ddpr, comm):
result = arr.global_index.get_slice(index, local_distribution)
return proxyize(result)

new_distribution = self.distribution.slice(index)
ddpr = new_distribution.get_dim_data_per_rank()

args = [self.key, index, ddpr, new_distribution.comm]
targets = new_distribution.targets
result = self.context.apply(get_view, args=args, targets=targets)[0]

return DistArray.from_localarrays(key=result,
targets=targets,
distribution=new_distribution,
dtype=self.dtype)

def _get_value(self, index):

# to be run locally
def get_value(arr, index):
return arr.global_index[index]

args = [self.key, index]
targets = self.distribution.owning_targets(index)
result = self.context.apply(get_value, args=args, targets=targets)

return [i for i in result if i is not None][0]

def _checked_getitem(self, index):

# to be run locally
def checked_getitem(arr, index):
return arr.global_index.checked_getitem(index)

args = [self.key, index]
targets = self.distribution.owning_targets(index)
result = self.context.apply(checked_getitem, args=args, targets=targets)

somethings = [i for i in result if i is not None]
if len(somethings) == 0:
# all return None
raise IndexError("Index %r is is not present." % (index,))
elif len(somethings) == 1:
return somethings[0]
else:
return result

def __getitem__(self, index):
return_type, index = sanitize_indices(index, ndim=self.ndim,
shape=self.shape)
return_proxy = (return_type == 'view')
targets = self.distribution.owning_targets(index) or [0]
if not self.distribution.has_precise_index:
result = self._checked_getitem(index)
elif return_type == 'view':
result = self._get_view(index)
elif return_type == 'value':
result = self._get_value(index)
else:
assert False

args = [self.key, index]
new_distribution = None
if self.distribution.has_precise_index:
if return_proxy: # returning a new DistArray view
new_distribution = self.distribution.slice(index)
targets = new_distribution.targets
ddpr = new_distribution.get_dim_data_per_rank()
args.extend([ddpr, new_distribution.comm])
local_fn = get_slice
else: # returning a value
local_fn = raw_getitem
else: # returning a value from unstructured
local_fn = checked_getitem

result = self.context.apply(local_fn, args=args, targets=targets)
return self._process_return_value(result, return_proxy, index, targets,
new_distribution=new_distribution)
return result

def __setitem__(self, index, value):
# to be run locally
Expand Down