diff --git a/distarray/dist/distarray.py b/distarray/dist/distarray.py index 8fd0853e..7592a7a2 100644 --- a/distarray/dist/distarray.py +++ b/distarray/dist/distarray.py @@ -16,6 +16,7 @@ import operator from itertools import product +from functools import reduce import numpy as np @@ -31,44 +32,6 @@ # Code # --------------------------------------------------------------------------- -def process_return_value(subcontext, result_key, targets): - """Figure out what to return on the Client. - - Parameters - ---------- - key : string - Key corresponding to wrapped function's return value. - - Returns - ------- - A DistArray (if locally all values are DistArray), a None (if - locally all values are None), or else, pull the result back to the - client and return it. If all but one of the pulled values is None, - return that non-None value only. - """ - type_key = subcontext._generate_key() - type_statement = "{} = str(type({}))".format(type_key, result_key) - subcontext._execute(type_statement, targets=targets) - result_type_str = subcontext._pull(type_key, targets=targets) - - def is_NoneType(typestring): - return (typestring == "" or - typestring == "") - - def is_LocalArray(typestring): - return typestring == "" - - if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context=subcontext) - elif all(is_NoneType(r) for r in result_type_str): - result = None - else: - result = subcontext._pull(result_key, targets=targets) - if has_exactly_one(result): - result = next(x for x in result if x is not None) - - return result - _DIM_DATA_PER_RANK = """ {ddpr_name} = {local_name}.dim_data """ @@ -158,21 +121,28 @@ def __getitem__(self, index): # especially for special cases like `index == slice(None)`. # This would dramatically improve tondarray's performance. + # func that runs locally + def getit(arr, index): + return arr.checked_getitem(index) + if isinstance(index, int) or isinstance(index, slice): tuple_index = (index,) return self.__getitem__(tuple_index) elif isinstance(index, tuple): targets = self.distribution.owning_targets(index) - result_key = self.context._generate_key() - fmt = '%s = %s.checked_getitem(%s)' - statement = fmt % (result_key, self.key, index) - self.context._execute(statement, targets=targets) - result = process_return_value(self.context, result_key, targets=targets) - if result is None: + + args = (self.key, index) + result = self.context.apply(getit, args=args, + targets=targets) + result = [i for i in result if i is not None] + if len(result) != 1: + raise IndexError("Getting more than one result (%s) is not " + " supported yet." % (result,)) + elif result is None: raise IndexError("Index %r is out of bounds" % (index,)) else: - return result + return result[0] else: raise TypeError("Invalid index type.") @@ -183,20 +153,24 @@ def __setitem__(self, index, value): # `value` and assign to local arrays. This would dramatically # improve the fromndarray method's performance. + def setit(arr, index, value): + return arr.checked_setitem(index, value) + if isinstance(index, int) or isinstance(index, slice): tuple_index = (index,) return self.__setitem__(tuple_index, value) elif isinstance(index, tuple): targets = self.distribution.owning_targets(index) - result_key = self.context._generate_key() - fmt = '%s = %s.checked_setitem(%s, %s)' - statement = fmt % (result_key, self.key, index, value) - self.context._execute(statement, targets=targets) - result = process_return_value(self.context, result_key, targets=targets) - if result is None: - raise IndexError("Index %r is out of bounds" % (index,)) - + args = (self.key, index, value) + result = self.context.apply(setit, args=args, + targets=targets) + result = [i for i in result if i is not None] + if len(result) > 1: + raise IndexError("Setting more than one result (%s) is not " + " supported yet." % (result,)) + elif result == []: + raise IndexError("Index %s is out of bounds" % (index,)) else: raise TypeError("Invalid index type.")