Skip to content
Merged
129 changes: 74 additions & 55 deletions distarray/dist/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
# Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------

"""
`Context` objects contain the information required for distarrays to
communicate with localarrays.
`Context` objects contain the information required for `DistArray`s to
communicate with `LocalArray`s.
"""

from __future__ import absolute_import
Expand All @@ -26,6 +27,7 @@


class Context(object):

"""
Context objects manage the setup and communication of the worker processes
for DistArray objects. A DistArray object has a context, and contexts have
Expand All @@ -35,7 +37,6 @@ class Context(object):
Typically there is just one context object that uses all processes,
although it is possible to have more than one context with a different
selection of engines.

"""

_CLEANUP = None
Expand Down Expand Up @@ -200,14 +201,10 @@ def _push0(self, d):
def _pull0(self, k):
# Blocking pull (block=True) of key `k` through self.view, restricted to
# the first entry of self.targets; the pulled result is returned as-is.
return self.view.pull(k, targets=self.targets[0], block=True)

def _create_local(self, local_call, shape, dist, grid_shape, dtype):
def _create_local(self, local_call, distribution, dtype):
"""Creates LocalArrays with the method named in `local_call`."""
da_key = self._generate_key()
comm_name = self._comm_key
distribution = Distribution.from_shape(context=self,
shape=shape,
dist=dist,
grid_shape=grid_shape)
ddpr = distribution.get_dim_data_per_rank()
ddpr_name, dtype_name = self._key_and_push(ddpr, dtype)
cmd = ('{da_key} = {local_call}(distarray.local.maps.Distribution('
Expand All @@ -217,26 +214,53 @@ def _create_local(self, local_call, shape, dist, grid_shape, dtype):
return DistArray.from_localarrays(da_key, distribution=distribution,
dtype=dtype)

def zeros(self, shape, dtype=float, dist=None, grid_shape=None):
if dist is None:
dist = {0: 'b'}
return self._create_local(local_call='distarray.local.zeros',
shape=shape, dist=dist,
grid_shape=grid_shape, dtype=dtype)
def empty(self, distribution, dtype=float):
"""Create an empty DistArray.

def ones(self, shape, dtype=float, dist=None, grid_shape=None):
if dist is None:
dist = {0: 'b'}
return self._create_local(local_call='distarray.local.ones',
shape=shape, dist=dist,
grid_shape=grid_shape, dtype=dtype,)
Parameters
----------
distribution : Distribution object
dtype : NumPy dtype, optional (default float)

def empty(self, shape, dtype=float, dist=None, grid_shape=None):
if dist is None:
dist = {0: 'b'}
Returns
-------
DistArray
A DistArray distributed as specified, with uninitialized values.
"""
return self._create_local(local_call='distarray.local.empty',
shape=shape, dist=dist,
grid_shape=grid_shape, dtype=dtype)
distribution=distribution, dtype=dtype)

def zeros(self, distribution, dtype=float):
"""Create a DistArray filled with zeros.

Parameters
----------
distribution : Distribution object
dtype : NumPy dtype, optional (default float)

Returns
-------
DistArray
A DistArray distributed as specified, filled with zeros.
"""
return self._create_local(local_call='distarray.local.zeros',
distribution=distribution, dtype=dtype)

def ones(self, distribution, dtype=float):
"""Create a DistArray filled with ones.

Parameters
----------
distribution : Distribution object
dtype : NumPy dtype, optional (default float)

Returns
-------
DistArray
A DistArray distributed as specified, filled with ones.
"""
return self._create_local(local_call='distarray.local.ones',
distribution=distribution, dtype=dtype,)

def save_dnpy(self, name, da):
"""
Expand Down Expand Up @@ -387,51 +411,41 @@ def save_hdf5(self, filename, da, key='buffer', mode='a'):
'distarray.local.save_hdf5(%s, %s, %s, %s)' % subs
)

def load_npy(self, filename, dim_data_per_rank):
def load_npy(self, filename, distribution):
"""
Load a DistArray from a dataset in a ``.npy`` file.

Parameters
----------
filename : str
Filename to load.
dim_data_per_rank : sequence of tuples of dict
A "dim_data" data structure for every rank. Described here:
https://github.com/enthought/distributed-array-protocol
distribution : Distribution object

Returns
-------
result : DistArray
A DistArray encapsulating the file loaded.

"""
if len(self.targets) != len(dim_data_per_rank):
errmsg = "`dim_data_per_rank` must contain a dim_data for every rank."
raise TypeError(errmsg)

da_key = self._generate_key()
subs = ((da_key,) + self._key_and_push(filename, dim_data_per_rank) +
ddpr = distribution.get_dim_data_per_rank()
subs = ((da_key,) + self._key_and_push(filename, ddpr) +
(self._comm_key,) + (self._comm_key,))

self._execute(
'%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs
)

distribution = Distribution.from_dim_data_per_rank(self,
dim_data_per_rank)
return DistArray.from_localarrays(da_key, distribution=distribution)

def load_hdf5(self, filename, dim_data_per_rank, key='buffer'):
def load_hdf5(self, filename, distribution, key='buffer'):
"""
Load a DistArray from a dataset in an ``.hdf5`` file.

Parameters
----------
filename : str
Filename to load.
dim_data_per_rank : sequence of tuples of dict
A "dim_data" data structure for every rank. Described here:
https://github.com/enthought/distributed-array-protocol
distribution : Distribution object
key : str, optional
The identifier for the group to load the DistArray from (the
default is 'buffer').
Expand All @@ -448,29 +462,34 @@ def load_hdf5(self, filename, dim_data_per_rank, key='buffer'):
errmsg = "An MPI-enabled h5py must be available to use load_hdf5."
raise ImportError(errmsg)

if len(self.targets) != len(dim_data_per_rank):
errmsg = "`dim_data_per_rank` must contain a dim_data for every rank."
raise TypeError(errmsg)

da_key = self._generate_key()
subs = ((da_key,) + self._key_and_push(filename, dim_data_per_rank) +
ddpr = distribution.get_dim_data_per_rank()
subs = ((da_key,) + self._key_and_push(filename, ddpr) +
(self._comm_key,) + self._key_and_push(key) + (self._comm_key,))

self._execute(
'%s = distarray.local.load_hdf5(%s, %s[%s.Get_rank()], %s, %s)' % subs
)
return DistArray.from_localarrays(da_key, distribution=distribution)

distribution = Distribution.from_dim_data_per_rank(self,
dim_data_per_rank)
def fromndarray(self, arr, distribution=None):
"""Create a DistArray from an ndarray.

return DistArray.from_localarrays(da_key, distribution=distribution)
Parameters
----------
arr : numpy.ndarray
The array whose values and dtype are copied into the new DistArray.
distribution : Distribution object, optional
If a Distribution object is not provided, one is created with
`Distribution.from_shape(self, arr.shape)`.

def fromndarray(self, arr, dist=None, grid_shape=None):
"""Convert an ndarray to a distarray."""
if dist is None:
dist = {0: 'b'}
out = self.empty(arr.shape, dtype=arr.dtype, dist=dist,
grid_shape=grid_shape)
Returns
-------
DistArray
A DistArray distributed as specified, using the values and dtype
from `arr`.
"""
if distribution is None:
distribution = Distribution.from_shape(self, arr.shape)
out = self.empty(distribution, dtype=arr.dtype)
for index, value in numpy.ndenumerate(arr):
out[index] = value
return out
Expand Down
4 changes: 1 addition & 3 deletions distarray/dist/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def __call__(self, *args, **kwargs):
for arg in args:
if isinstance(arg, DistArray):
# Create the output distarray.
out = context.empty(arg.shape, dtype=arg.dtype,
dist=arg.dist,
grid_shape=arg.grid_shape)
out = context.empty(arg.distribution, dtype=arg.dtype)
# parse args
args_str, kwargs_str = self.key_and_push_args(
args, kwargs, context=context,
Expand Down
Loading