Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions distarray/dist/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def process_return_value(self, context, result_key):
client and return it. If all but one of the pulled values is None,
return that non-None value only.
"""
type_key = context._generate_key()
type_statement = "{} = str(type({}))".format(type_key, result_key)
context._execute(type_statement, targets=context.targets)
result_type_str = context._pull(type_key, targets=context.targets)
def get_type_str(key):
return str(type(key))
result_type_str = context.apply(get_type_str, args=(result_key,),
targets=context.targets)

def is_NoneType(typestring):
return (typestring == "<type 'NoneType'>" or
Expand Down
25 changes: 10 additions & 15 deletions distarray/dist/distarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,7 @@ def targets(self):
def tondarray(self):
"""Returns the distributed array as an ndarray."""
arr = np.empty(self.shape, dtype=self.dtype)
local_name = self.context._generate_key()
self.context._execute('%s = %s.copy()' % (local_name, self.key), targets=self.targets)
local_arrays = self.context._pull(local_name, targets=self.targets)
local_arrays = self.get_localarrays()
for local_array in local_arrays:
maps = (list(ax_map.global_iter) for ax_map in
local_array.distribution)
Expand Down Expand Up @@ -323,11 +321,9 @@ def get_ndarrays(self):
one ndarray per process

"""
key = self.context._generate_key()
self.context._execute('%s = %s.get_localarray()' % (key, self.key),
targets=self.targets)
result = self.context._pull(key, targets=self.targets)
return result
def get(key):
return key.get_localarray()
return self.context.apply(get, args=(self.key,), targets=self.targets)

def get_localarrays(self):
"""Pull the LocalArray objects from the engines.
Expand All @@ -338,15 +334,14 @@ def get_localarrays(self):
one localarray per process

"""
result = self.context._pull(self.key, targets=self.targets)
return result
def get(key):
return key.copy()
return self.context.apply(get, args=(self.key,), targets=self.targets)

def get_localshapes(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this method is refactored in another PR to remove all communication, correct?

key = self.context._generate_key()
self.context._execute('%s = %s.local_shape' % (key, self.key),
targets=self.targets)
result = self.context._pull(key, targets=self.targets)
return result
def get(key):
return key.local_shape
return self.context.apply(get, args=(self.key,), targets=self.targets)

# Binary operators

Expand Down