Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion distarray/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from distarray.client_map import ClientMDMap
from distarray.ipython_utils import IPythonClient

DISTARRAY_BASE_NAME = '__distarray__'


class Context(object):
'''
Expand Down Expand Up @@ -109,6 +111,18 @@ def get_size():
block=True
)

# `localize` and `vectorize` allow extra functions to be added to the context.

def localize(self, func):
from distarray.decorators import Localize
lf = Localize(func, self)
setattr(self, func.__name__, lf)

def vectorize(self, func):
from distarray.decorators import Vectorize
lf = Vectorize(func, self)
setattr(self, func.__name__, lf)

# Key management routines:

def _setup_key_context(self):
Expand All @@ -123,7 +137,7 @@ def _setup_key_context(self):

def _key_basename(self):
""" Get the base name for all keys. """
return '_distarray_key'
return DISTARRAY_BASE_NAME

def _key_prefix(self):
""" Generate a prefix for a key name for this context. """
Expand Down
247 changes: 107 additions & 140 deletions distarray/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,108 +11,59 @@
import functools

from distarray.client import DistArray
from distarray.context import Context
from distarray.context import DISTARRAY_BASE_NAME
from distarray.error import ContextError
from distarray.utils import has_exactly_one
from distarray.externals.six import string_types


class DecoratorBase(object):
"""
Base class for decorators, handles name wrapping and allows the
decorator to take an optional kwarg.
class FunctionRegistrationBase(object):
"""
Base class for local function registration.

Subclasses:
Localize
Vectorize

def __init__(self, fn):
self.fn = fn
self.fn_key = self.fn.__name__
functools.update_wrapper(self, fn)
self.context = None
"""

def push_fn(self, context, fn_key, fn):
"""Push function to the engines."""
context._push({fn_key: fn})
def __init__(self, func, context):
self.func = func
self.func_key = self.func.__name__
functools.update_wrapper(self, func)
self.context = context

def determine_context(self, args, kwargs):
""" Determine a context from a functions arguments."""

contexts = []
# inspect args for a context
for arg in args + tuple(kwargs.values()):
if isinstance(arg, DistArray):
contexts.append(arg.context)
elif isinstance(arg, Context):
contexts.append(arg)

# check the args had a context
if contexts == []:
raise TypeError('Function must take DistArray or Context objects.')
if arg.context != self.context:
msg = "DistArray %r not in same context as registered function %r."
raise ContextError(msg % (arg, self.func))

# check that all contexts are equal
if not contexts.count(contexts[0]) == len(contexts):
msg = ("Arguments must use the same Context (given arguments of "
"type %r)")
msg %= (tuple(set(contexts)),)
raise ContextError(msg)
return self.context

return contexts[0]

def key_and_push_args(self, args, kwargs, context=None, da_handler=None):
"""
Push a tuple of args and dict of kwargs to the engines. Return a
tuple with keys corresponding to args values on the engines. And a
dictionary with the same keys and values which are the keys to the
input dictionary's values.

This allows us to use the following interface to execute code on
the engines:

>>> def foo(*args, **kwargs):
>>> args, kwargs = _key_and_push_args(args, kwargs)
>>> exec_str = "remote_foo(*%s, **%s)"
>>> exec_str %= (args, kwargs)
>>> context.execute(exec_str)
def build_args(self, args, kwargs):
"""
Returns a new args tuple and kwargs dictionary with all distarrays in
the original args and kwargs arguments replaced by their .keys.

if context is None:
context = self.determine_context(args, kwargs)
"""

# handle positional arguments
arg_keys = []
push_keys = {}
for arg in args:
args = list(args)
for idx, arg in enumerate(args):
if isinstance(arg, DistArray):
if da_handler is None:
arg_keys.append(arg.key)
# da_handler handles distarrays.
else:
arg_keys = da_handler(arg, arg_keys)
else:
new_key = context._generate_key()
arg_keys.append(new_key)
push_keys[new_key] = arg
args[idx] = arg.key

# handle key word arguments
for kw in kwargs:
if isinstance(kwargs[kw], DistArray):
kwargs[kw] = kwargs[kw].key
else:
new_key = context._generate_key()
push_keys[new_key] = kwargs[kw]
kwargs[kw] = new_key

# push the keys to the engines
context._push(push_keys)

# build arg string
arg_str = '(' + ', '.join(arg_keys) + ',)'
for k, v in kwargs.items():
if isinstance(v, DistArray):
kwargs[k] = v.key

# build kwarg string
kwarg_iter = ["'%s': %s" % (k, v) for (k, v) in kwargs.items()]
kwarg_str = '{' + ', '.join(kwarg_iter) + '}'
return args, kwargs

return arg_str, kwarg_str

def process_return_value(self, context, result_key):
def process_return_value(self, result_from_target):
"""Figure out what to return on the Client.

Parameters
Expand All @@ -128,85 +79,101 @@ def process_return_value(self, context, result_key):
client and return it. If all but one of the pulled values is None,
return that non-None value only.
"""
type_key = context._generate_key()
type_statement = "{} = str(type({}))".format(type_key, result_key)
context._execute(type_statement)
result_type_str = context._pull(type_key)

def is_NoneType(typestring):
return (typestring == "<type 'NoneType'>" or
typestring == "<class 'NoneType'>")

def is_LocalArray(typestring):
return (typestring == "<class 'distarray.local.localarray."
"LocalArray'>")

if all(is_LocalArray(r) for r in result_type_str):
result = DistArray.from_localarrays(result_key, context)
elif all(is_NoneType(r) for r in result_type_str):

results = list(result_from_target.values())

if all(isinstance(r, string_types) and r.startswith(DISTARRAY_BASE_NAME)
for r in results):
result = DistArray.from_localarrays(results[0], self.context)
elif all(r is None for r in results):
result = None
else:
result = context._pull(result_key)
if has_exactly_one(result):
result = next(x for x in result if x is not None)

non_nones = [r for r in results if r is not None]
if len(non_nones) == 1:
result = non_nones[0]
else:
result = results
return result


class local(DecoratorBase):
"""Decorator to run a function locally on the engines."""
def _rpc_localize(func, args, kwargs, result_key, prefix):

ns = __import__('__main__')

from distarray.local.localarray import LocalArray
from distarray.externals.six import string_types

args = list(args)
for idx, a in enumerate(args):
if isinstance(a, string_types):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This arg and kwarg logic is duplicated below. So maybe it should be a method on DecoratorBase.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are the functions that are pushed to the engines, so it wouldn't make sense to have them as methods on DecoratorBase. I'd like to remove the duplication, but there isn't a clean way of doing so at the moment.

if a.startswith(prefix):
args[idx] = getattr(ns, a)

for k, v in kwargs.items():
if isinstance(v, string_types):
if v.startswith(prefix):
kwargs[k] = getattr(ns, v)

res = func(*args, **kwargs)
if isinstance(res, LocalArray):
setattr(ns, result_key, res)
return result_key
return res


class Localize(FunctionRegistrationBase):
"""Runs a function locally on the engines."""

def __call__(self, *args, **kwargs):
# get context from args
context = self.determine_context(args, kwargs)
# push function
self.push_fn(context, self.fn_key, self.fn)

args, kwargs = self.key_and_push_args(args, kwargs,
context=context)
args, kwargs = self.build_args(args, kwargs)
result_key = context._generate_key()
results = context.view.apply_async(_rpc_localize, self.func,
args, kwargs, result_key,
DISTARRAY_BASE_NAME).get_dict()
return self.process_return_value(results)


def _rpc_vectorize(func, args, kwargs, out, prefix):

ns = __import__('__main__')
import numpy as np
from distarray.externals.six import string_types

exec_str = "%s = %s(*%s, **%s)"
exec_str %= (result_key, self.fn_key, args, kwargs)
context._execute(exec_str)
args = list(args)
for idx, a in enumerate(args):
if isinstance(a, string_types):
if a.startswith(prefix):
args[idx] = getattr(ns, a).local_array

return self.process_return_value(context, result_key)
for k, v in kwargs.items():
if isinstance(v, string_types):
if v.startswith(prefix):
kwargs[k] = getattr(ns, v).local_array

out = getattr(ns, out)

class vectorize(DecoratorBase):
func = np.vectorize(func)
out.local_array = func(*args)


class Vectorize(FunctionRegistrationBase):
"""
Analogous to numpy.vectorize. Input DistArray's must all be the
same shape, and this will be the shape of the output distarray.
Like `Localize`, but vectorizes the function with numpy.vectorize and runs
it on the engines.
"""

def get_local_array(self, da, arg_keys):
return arg_keys + [da.key + '.local_array']

def __call__(self, *args, **kwargs):
# get context from args
context = self.determine_context(args, kwargs)
# push function
self.push_fn(context, self.fn_key, self.fn)
# vectorize the function
exec_str = "%s = numpy.vectorize(%s)" % (self.fn_key, self.fn_key)
context._execute(exec_str)

# Find the first distarray, they should all be the same up to the data.
# TODO: FIXME: This uses an extra round-trip (or two (or three)) to
# create the `out` array. Better would be to create a new LocalArray
# inside _rpc_vectorize and return its metadata to create a DistArray
# using `.from_localarrays()`.
for arg in args:
if isinstance(arg, DistArray):
# Create the output distarray.
out = context.empty(arg.shape, dtype=arg.dtype,
dist=arg.dist,
grid_shape=arg.grid_shape)
# parse args
args_str, kwargs_str = self.key_and_push_args(
args, kwargs, context=context,
da_handler=self.get_local_array)

# Call the function
exec_str = ("if %s.local_array.size != 0: %s.local_array = "
"%s(*%s, **%s)")
exec_str %= (out.key, out.key, self.fn_key, args_str,
kwargs_str)
context._execute(exec_str)
return out
dist=arg.dist, grid_shape=arg.grid_shape)
args, kwargs = self.build_args(args, kwargs)
context.view.apply_sync(_rpc_vectorize, self.func,
args, kwargs, out.key, DISTARRAY_BASE_NAME)
return out
6 changes: 6 additions & 0 deletions distarray/local/localarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
from distarray.local.error import InvalidDimensionError, IncompatibleArrayError


def _rpc(payload):
funcname, args, kwargs = payload
func = globals()[funcname]
return func(*args, **kwargs)


def _start_stop_block(size, proc_grid_size, proc_grid_rank):
nelements = size // proc_grid_size
if size % proc_grid_size != 0:
Expand Down
33 changes: 16 additions & 17 deletions distarray/plotting/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,7 @@

from six.moves import range
from matplotlib import pyplot, colors, cm
from numpy import arange, concatenate, empty, linspace, resize

from distarray.decorators import local


@local
def _get_ranks(arr):
"""
Given a distarray arr, return a distarray with the same shape, but
with the elements equal to the rank of the process the element is
on.
"""
out = arr.copy()
out.local_array[:] = arr.comm_rank
out.local_array = out.local_array.astype(int)
return out
from numpy import concatenate, empty, linspace


def cmap_discretize(cmap, N):
Expand Down Expand Up @@ -272,8 +257,22 @@ def plot_array_distribution(darray,
# This is based somewhat on:
# http://matplotlib.org/examples/api/colorbar_only.html

def _get_ranks(arr):
"""
Given a distarray arr, return a distarray with the same shape, but
with the elements equal to the rank of the process the element is
on.
"""
out = arr.copy()
out.local_array[:] = arr.comm_rank
out.local_array = out.local_array.astype(int)
return out

ctx = darray.context
ctx.register(_get_ranks)

# Process per element.
process_darray = _get_ranks(darray)
process_darray = ctx._get_ranks(darray)
process_array = process_darray.toarray()

# Values per element.
Expand Down
Loading