Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 27 additions & 53 deletions distarray/dist/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import operator
from itertools import product
from functools import reduce

import numpy as np

Expand All @@ -31,44 +32,6 @@
# Code
# ---------------------------------------------------------------------------

def process_return_value(subcontext, result_key, targets):
"""Figure out what to return on the Client.

Parameters
----------
key : string
Key corresponding to wrapped function's return value.

Returns
-------
A DistArray (if locally all values are DistArray), a None (if
locally all values are None), or else, pull the result back to the
client and return it. If all but one of the pulled values is None,
return that non-None value only.
"""
type_key = subcontext._generate_key()
type_statement = "{} = str(type({}))".format(type_key, result_key)
subcontext._execute(type_statement, targets=targets)
result_type_str = subcontext._pull(type_key, targets=targets)

def is_NoneType(typestring):
return (typestring == "<type 'NoneType'>" or
typestring == "<class 'NoneType'>")

def is_LocalArray(typestring):
return typestring == "<class 'distarray.local.localarray.LocalArray'>"

if all(is_LocalArray(r) for r in result_type_str):
result = DistArray.from_localarrays(result_key, context=subcontext)
elif all(is_NoneType(r) for r in result_type_str):
result = None
else:
result = subcontext._pull(result_key, targets=targets)
if has_exactly_one(result):
result = next(x for x in result if x is not None)

return result

_DIM_DATA_PER_RANK = """
{ddpr_name} = {local_name}.dim_data
"""
Expand Down Expand Up @@ -158,21 +121,28 @@ def __getitem__(self, index):
# especially for special cases like `index == slice(None)`.
# This would dramatically improve tondarray's performance.

# func that runs locally
def getit(arr, index):
return arr.checked_getitem(index)

if isinstance(index, int) or isinstance(index, slice):
tuple_index = (index,)
return self.__getitem__(tuple_index)

elif isinstance(index, tuple):
targets = self.distribution.owning_targets(index)
result_key = self.context._generate_key()
fmt = '%s = %s.checked_getitem(%s)'
statement = fmt % (result_key, self.key, index)
self.context._execute(statement, targets=targets)
result = process_return_value(self.context, result_key, targets=targets)
if result is None:

args = (self.key, index)
result = self.context.apply(getit, args=args,
targets=targets)
result = [i for i in result if i is not None]
if len(result) != 1:
raise IndexError("Getting more than one result (%s) is not "
" supported yet." % (result,))
elif result is None:
raise IndexError("Index %r is out of bounds" % (index,))
else:
return result
return result[0]
else:
raise TypeError("Invalid index type.")

Expand All @@ -183,20 +153,24 @@ def __setitem__(self, index, value):
# `value` and assign to local arrays. This would dramatically
# improve the fromndarray method's performance.

def setit(arr, index, value):
return arr.checked_setitem(index, value)

if isinstance(index, int) or isinstance(index, slice):
tuple_index = (index,)
return self.__setitem__(tuple_index, value)

elif isinstance(index, tuple):
targets = self.distribution.owning_targets(index)
result_key = self.context._generate_key()
fmt = '%s = %s.checked_setitem(%s, %s)'
statement = fmt % (result_key, self.key, index, value)
self.context._execute(statement, targets=targets)
result = process_return_value(self.context, result_key, targets=targets)
if result is None:
raise IndexError("Index %r is out of bounds" % (index,))

args = (self.key, index, value)
result = self.context.apply(setit, args=args,
targets=targets)
result = [i for i in result if i is not None]
if len(result) > 1:
raise IndexError("Setting more than one result (%s) is not "
" supported yet." % (result,))
elif result == []:
raise IndexError("Index %s is out of bounds" % (index,))
else:
raise TypeError("Invalid index type.")

Expand Down