Skip to content
1 change: 1 addition & 0 deletions distarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@

from distarray.client import DistArray
from distarray.context import Context
from distarray.creation import *
from distarray.functions import *
100 changes: 100 additions & 0 deletions distarray/creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# encoding: utf-8
#----------------------------------------------------------------------------
# Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
# Distributed under the terms of the BSD License. See COPYING.rst.
#----------------------------------------------------------------------------
"""
DistArray creation functions.
"""

from functools import wraps

import numpy

from distarray.client import DistArray
from distarray.world import WORLD


def _create_local(context, local_call, shape, dtype, dist, grid_shape):
"""Creates a local array, according to the method named in `local_call`."""
keys = context._key_and_push(shape, dtype, dist, grid_shape)
shape_name, dtype_name, dist_name, grid_shape_name = keys
da_key = context._generate_key()
comm = context._comm_key
cmd = ('{da_key} = {local_call}({shape_name}, {dtype_name}, {dist_name}, '
'{grid_shape_name}, {comm})')
context._execute(cmd.format(**locals()))
return DistArray.from_localarrays(da_key, context)


def from_dim_data(dim_data_per_rank, context=WORLD, dtype=float):
"""Make a DistArray from dim_data structures.

Parameters
----------
dim_data_per_rank : sequence of tuples of dict
A "dim_data" data structure for every rank. Described here:
https://github.com/enthought/distributed-array-protocol
dtype : numpy dtype, optional
dtype for underlying arrays

Returns
-------
result : DistArray
An empty DistArray of the specified size, dimensionality, and
distribution.

"""
if len(context.targets) != len(dim_data_per_rank):
errmsg = "`dim_data_per_rank` must contain a dim_data for every rank."
raise TypeError(errmsg)

da_key = context._generate_key()
subs = ((da_key,) + context._key_and_push(dim_data_per_rank) +
(context._comm_key,) + context._key_and_push(dtype) +
(context._comm_key,))

cmd = ('%s = distarray.local.LocalArray.'
'from_dim_data(%s[%s.Get_rank()], dtype=%s, comm=%s)')
context._execute(cmd % subs)

return DistArray.from_localarrays(da_key, context)


def zeros(shape, context=WORLD, dtype=float, dist={0: 'b'}, grid_shape=None):
return _create_local(context, local_call='distarray.local.zeros',
shape=shape, dtype=dtype, dist=dist,
grid_shape=grid_shape)


def ones(shape, context=WORLD, dtype=float, dist={0: 'b'}, grid_shape=None):
return _create_local(context, local_call='distarray.local.ones',
shape=shape, dtype=dtype, dist=dist,
grid_shape=grid_shape)


def empty(shape, context=WORLD, dtype=float, dist={0: 'b'}, grid_shape=None):
return _create_local(context, local_call='distarray.local.empty',
shape=shape, dtype=dtype, dist=dist,
grid_shape=grid_shape)


def fromndarray(arr, context=WORLD, dist={0: 'b'}, grid_shape=None):
"""Convert an ndarray to a distarray."""
out = empty(arr.shape, dtype=arr.dtype, dist=dist, grid_shape=grid_shape)
for index, value in numpy.ndenumerate(arr):
out[index] = value
return out

fromarray = fromndarray


def fromfunction(function, shape, context=WORLD, **kwargs):
func_key = context._generate_key()
context.view.push_function({func_key: function}, targets=context.targets,
block=True)
keys = context._key_and_push(shape, kwargs)
new_key = context._generate_key()
subs = (new_key, func_key) + keys
context._execute('%s = distarray.local.fromfunction(%s,%s,**%s)' % subs)
return DistArray.from_localarrays(new_key, context)
6 changes: 3 additions & 3 deletions distarray/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from distarray.client import DistArray
from distarray.context import Context
from distarray.error import ContextError
from distarray.creation import empty
from distarray.utils import has_exactly_one


Expand Down Expand Up @@ -197,9 +198,8 @@ def __call__(self, *args, **kwargs):
for arg in args:
if isinstance(arg, DistArray):
# Create the output distarray.
out = self.context.empty(arg.shape, dtype=arg.dtype,
dist=arg.dist,
grid_shape=arg.grid_shape)
out = empty(arg.shape, context=self.context, dtype=arg.dtype,
dist=arg.dist, grid_shape=arg.grid_shape)
# parse args
args_str, kwargs_str = self.key_and_push_args(
args, kwargs, context=self.context,
Expand Down
5 changes: 2 additions & 3 deletions distarray/tests/ipcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
from __future__ import print_function

import sys
from distarray.externals import six
from time import sleep
from subprocess import Popen, PIPE


if six.PY2:
if sys.version_info[0] == 2:
ipcluster_cmd = 'ipcluster'
elif six.PY3:
elif sys.version_info[0] == 3:
ipcluster_cmd = 'ipcluster3'
else:
raise NotImplementedError("Not run with Python 2 *or* 3?")
Expand Down
Loading