diff --git a/distarray/client.py b/distarray/client.py index f02f0fe8..0276285a 100644 --- a/distarray/client.py +++ b/distarray/client.py @@ -57,7 +57,7 @@ def is_LocalArray(typestring): return typestring == "" if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, subcontext) + result = DistArray.from_localarrays(result_key, context=subcontext) elif all(is_NoneType(r) for r in result_type_str): result = None else: @@ -109,17 +109,38 @@ def __init__(self, distribution, dtype=float): self._dtype = dtype @classmethod - def from_localarrays(cls, key, context): - """ The caller has already created the LocalArray objects. `key` is + def from_localarrays(cls, key, context=None, distribution=None, + dtype=None): + """The caller has already created the LocalArray objects. `key` is their name on the engines. This classmethod creates a DistArray that refers to these LocalArrays. + Either a `context` or a `distribution` must also be provided. If + `context` is provided, a ``dim_data_per_rank`` will be pulled from + the existing ``LocalArray``s and a ``Distribution`` will be created + from it. If `distribution` is provided, it should accurately + reflect the distribution of the existing ``LocalArray``s. + + If `dtype` is not provided, it will be fetched from the engines. """ da = cls.__new__(cls) da.key = key - da.distribution = _make_distribution_from_dim_data_per_rank(key, - context) - da._dtype = _get_attribute(context, key, 'dtype') + + if (context is None) == (distribution is None): + errmsg = "Must provide `context` or `distribution` but not both." + raise RuntimeError(errmsg) + elif (distribution is not None): + da.distribution = distribution + context = distribution.context + elif (context is not None): + da.distribution = _make_distribution_from_dim_data_per_rank(key, + context) + + if dtype is None: + da._dtype = _get_attribute(context, key, 'dtype') + else: + da._dtype = dtype + return da def __del__(self): diff --git a/distarray/client_map.py b/distarray/client_map.py index d84ff5ae..f0abb16e 100644 --- a/distarray/client_map.py +++ b/distarray/client_map.py @@ -35,7 +35,8 @@ make_grid_shape, positivify, validate_grid_shape, - _start_stop_block) + _start_stop_block, + normalize_dim_dict) def _dedup_dim_dicts(dim_dicts): @@ -395,6 +396,9 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank): self = cls.__new__(cls) dd0 = dim_data_per_rank[0] self.context = context + for dim_data in dim_data_per_rank: + for dim_dict in dim_data: + normalize_dim_dict(dim_dict) self.shape = tuple(dd['size'] for dd in dd0) self.ndim = len(dd0) self.dist = tuple(dd['dist_type'] for dd in dd0) @@ -423,16 +427,19 @@ def from_dim_data_per_rank(cls, context, dim_data_per_rank): return self @classmethod - def from_shape(cls, context, shape, dist, grid_shape=None): + def from_shape(cls, context, shape, dist=None, grid_shape=None): self = cls.__new__(cls) self.context = context self.shape = shape self.ndim = len(shape) + + if dist is None: + dist = {0: 'b'} self.dist = normalize_dist(dist, self.ndim) if grid_shape is None: # Make a new grid_shape if not provided. - self.grid_shape = make_grid_shape(self.shape, dist, + self.grid_shape = make_grid_shape(self.shape, self.dist, len(context.targets)) else: # Otherwise normalize the one passed in. self.grid_shape = normalize_grid_shape(grid_shape, self.ndim) diff --git a/distarray/context.py b/distarray/context.py index 270ba293..498ec36a 100644 --- a/distarray/context.py +++ b/distarray/context.py @@ -200,14 +200,21 @@ def _pull0(self, k): return self.view.pull(k, targets=self.targets[0], block=True) def _create_local(self, local_call, shape, dist, grid_shape, dtype): - """ Creates a local array, according to the method named in `local_call`.""" - shape_name, dist_name, grid_shape_name, dtype_name = \ - self._key_and_push(shape, dist, grid_shape, dtype) + """Creates LocalArrays with the method named in `local_call`.""" da_key = self._generate_key() - comm_key = self._comm_key - cmd = '{da_key} = {local_call}(distarray.local.maps.Distribution.from_shape({shape_name}, {dist_name}, {grid_shape_name}, {comm_key}), {dtype_name})' + comm_name = self._comm_key + distribution = Distribution.from_shape(context=self, + shape=shape, + dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + ddpr_name, dtype_name = self._key_and_push(ddpr, dtype) + cmd = ('{da_key} = {local_call}(distarray.local.maps.Distribution(' + '{ddpr_name}[{comm_name}.Get_rank()], comm={comm_name}), ' + 'dtype={dtype_name})') self._execute(cmd.format(**locals())) - return DistArray.from_localarrays(da_key, self) + return DistArray.from_localarrays(da_key, distribution=distribution, + dtype=dtype) def zeros(self, shape, dtype=float, dist=None, grid_shape=None): if dist is None: @@ -340,7 +347,7 @@ def load_dnpy(self, name): errmsg = "`name` must be a string or a list." raise TypeError(errmsg) - return DistArray.from_localarrays(da_key, self) + return DistArray.from_localarrays(da_key, context=self) def save_hdf5(self, filename, da, key='buffer', mode='a'): """ @@ -409,7 +416,9 @@ def load_npy(self, filename, dim_data_per_rank): '%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs ) - return DistArray.from_localarrays(da_key, self) + distribution = Distribution.from_dim_data_per_rank(self, + dim_data_per_rank) + return DistArray.from_localarrays(da_key, distribution=distribution) def load_hdf5(self, filename, dim_data_per_rank, key='buffer'): """ @@ -450,7 +459,10 @@ def load_hdf5(self, filename, dim_data_per_rank, key='buffer'): '%s = distarray.local.load_hdf5(%s, %s[%s.Get_rank()], %s, %s)' % subs ) - return DistArray.from_localarrays(da_key, self) + distribution = Distribution.from_dim_data_per_rank(self, + dim_data_per_rank) + + return DistArray.from_localarrays(da_key, distribution=distribution) def fromndarray(self, arr, dist=None, grid_shape=None): """Convert an ndarray to a distarray.""" @@ -472,4 +484,4 @@ def fromfunction(self, function, shape, **kwargs): new_key = self._generate_key() subs = (new_key, func_key) + keys self._execute('%s = distarray.local.fromfunction(%s,%s,**%s)' % subs) - return DistArray.from_localarrays(new_key, self) + return DistArray.from_localarrays(new_key, context=self) diff --git a/distarray/decorators.py b/distarray/decorators.py index 2c3c53bf..3dd60150 100644 --- a/distarray/decorators.py +++ b/distarray/decorators.py @@ -142,7 +142,7 @@ def is_LocalArray(typestring): "LocalArray'>") if all(is_LocalArray(r) for r in result_type_str): - result = DistArray.from_localarrays(result_key, context) + result = DistArray.from_localarrays(result_key, context=context) elif all(is_NoneType(r) for r in result_type_str): result = None else: diff --git a/distarray/functions.py b/distarray/functions.py index 99f8aef3..a4fe8568 100644 --- a/distarray/functions.py +++ b/distarray/functions.py @@ -45,7 +45,8 @@ def proxy_func(a, *args, **kwargs): exec_str %= (new_key, name, a.key) context._execute(exec_str) - return DistArray.from_localarrays(new_key, context) + return DistArray.from_localarrays(new_key, + distribution=a.distribution) return proxy_func @@ -57,12 +58,15 @@ def proxy_func(a, b, *args, **kwargs): if is_a_dap and is_b_dap: a_key = a.key b_key = b.key + distribution = a.distribution elif is_a_dap and numpy.isscalar(b): a_key = a.key b_key = context._key_and_push(b)[0] + distribution = a.distribution elif is_b_dap and numpy.isscalar(a): a_key = context._key_and_push(a)[0] b_key = b.key + distribution = b.distribution else: raise TypeError('only DistArray or scalars are accepted') new_key = context._generate_key() @@ -75,7 +79,7 @@ def proxy_func(a, b, *args, **kwargs): exec_str %= (new_key, name, a_key, b_key) context._execute(exec_str) - return DistArray.from_localarrays(new_key, context) + return DistArray.from_localarrays(new_key, distribution=distribution) return proxy_func diff --git a/distarray/metadata_utils.py b/distarray/metadata_utils.py index 1bb60fb0..3544f747 100644 --- a/distarray/metadata_utils.py +++ b/distarray/metadata_utils.py @@ -69,6 +69,8 @@ def make_grid_shape(shape, dist, comm_size): if not possible to distribute `comm_size` processes over number of dimensions. """ + if not isinstance(dist, Sequence): + raise TypeError("`dist` argument should be a Sequence.") distdims = tuple(i for (i, v) in enumerate(dist) if v != 'n') ndistdim = len(distdims) @@ -194,6 +196,16 @@ def distribute_indices(dd): raise TypeError(msg % dist_type) +def normalize_dim_dict(dd): + """Fill out some degenerate dim_dicts.""" + + # TODO: Fill out empty dim_dict alias here? + + if dd['dist_type'] == 'n': + dd['proc_grid_size'] = 1 + dd['proc_grid_rank'] = 0 + + def positivify(index, size): if 0 <= index < size: return index diff --git a/distarray/random.py b/distarray/random.py index dc937b59..2eb101f8 100644 --- a/distarray/random.py +++ b/distarray/random.py @@ -9,6 +9,7 @@ from distarray.client import DistArray +from distarray.client_map import Distribution class Random(object): @@ -68,17 +69,20 @@ def rand(self, size=None, dist=None, grid_shape=None): if dist is None: dist = {0: 'b'} da_key = self.context._generate_key() - size_key, dist_key, grid_shape_key = \ - self.context._key_and_push(size, dist, grid_shape) - comm_key = self.context._comm_key + + distribution = Distribution.from_shape(context=self.context, + shape=size, + dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + ddpr_name = self.context._key_and_push(ddpr)[0] + comm_name = self.context._comm_key self.context._execute( '{da_key} = distarray.local.random.rand(' - 'distribution=distarray.local.maps.Distribution.from_shape(' - 'shape={size_key}, dist={dist_key},' - 'grid_shape={grid_shape_key}, comm={comm_key}' - '))'.format(**locals()) - ) - return DistArray.from_localarrays(da_key, self.context) + 'distribution=distarray.local.maps.Distribution(' + 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' + 'comm={comm_name}))'.format(**locals())) + return DistArray.from_localarrays(da_key, distribution=distribution) def normal(self, loc=0.0, scale=1.0, size=None, dist=None, grid_shape=None): @@ -140,18 +144,22 @@ def normal(self, loc=0.0, scale=1.0, size=None, dist=None, if dist is None: dist = {0: 'b'} da_key = self.context._generate_key() - loc_key, scale_key, size_key, dist_key, grid_shape_key = \ - self.context._key_and_push(loc, scale, size, dist, grid_shape) - comm_key = self.context._comm_key + + distribution = Distribution.from_shape(context=self.context, + shape=size, + dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + loc_name, scale_name, ddpr_name = \ + self.context._key_and_push(loc, scale, ddpr) + comm_name = self.context._comm_key self.context._execute( '{da_key} = distarray.local.random.normal(' - 'loc={loc_key}, scale={scale_key},' - 'distribution=distarray.local.maps.Distribution.from_shape(' - 'shape={size_key}, dist={dist_key},' - 'grid_shape={grid_shape_key}, comm={comm_key}' - '))'.format(**locals()) - ) - return DistArray.from_localarrays(da_key, self.context) + 'loc={loc_name}, scale={scale_name},' + 'distribution=distarray.local.maps.Distribution(' + 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' + 'comm={comm_name}))'.format(**locals())) + return DistArray.from_localarrays(da_key, distribution=distribution) def randint(self, low, high=None, size=None, dist=None, grid_shape=None): """ @@ -190,18 +198,22 @@ def randint(self, low, high=None, size=None, dist=None, grid_shape=None): if dist is None: dist = {0: 'b'} da_key = self.context._generate_key() - low_key, high_key, size_key, dist_key, grid_shape_key = \ - self.context._key_and_push(low, high, size, dist, grid_shape) - comm_key = self.context._comm_key + + distribution = Distribution.from_shape(context=self.context, + shape=size, + dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + low_name, high_name, ddpr_name = \ + self.context._key_and_push(low, high, ddpr) + comm_name = self.context._comm_key self.context._execute( '{da_key} = distarray.local.random.randint(' - 'low={low_key}, high={high_key},' - 'distribution=distarray.local.maps.Distribution.from_shape(' - 'shape={size_key}, dist={dist_key},' - 'grid_shape={grid_shape_key}, comm={comm_key}' - '))'.format(**locals()) - ) - return DistArray.from_localarrays(da_key, self.context) + 'low={low_name}, high={high_name},' + 'distribution=distarray.local.maps.Distribution(' + 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' + 'comm={comm_name}))'.format(**locals())) + return DistArray.from_localarrays(da_key, distribution=distribution) def randn(self, size=None, dist=None, grid_shape=None): """ @@ -229,14 +241,17 @@ def randn(self, size=None, dist=None, grid_shape=None): if dist is None: dist = {0: 'b'} da_key = self.context._generate_key() - size_key, dist_key, grid_shape_key = \ - self.context._key_and_push(size, dist, grid_shape) - comm_key = self.context._comm_key + + distribution = Distribution.from_shape(context=self.context, + shape=size, + dist=dist, + grid_shape=grid_shape) + ddpr = distribution.get_dim_data_per_rank() + ddpr_name = self.context._key_and_push(ddpr)[0] + comm_name = self.context._comm_key self.context._execute( '{da_key} = distarray.local.random.randn(' - 'distribution=distarray.local.maps.Distribution.from_shape(' - 'shape={size_key}, dist={dist_key},' - 'grid_shape={grid_shape_key}, comm={comm_key}' - '))'.format(**locals()) - ) - return DistArray.from_localarrays(da_key, self.context) + 'distribution=distarray.local.maps.Distribution(' + 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' + 'comm={comm_name}))'.format(**locals())) + return DistArray.from_localarrays(da_key, distribution=distribution)