-
Notifications
You must be signed in to change notification settings - Fork 1
Fix Distribution.from_dim_data_per_rank
#354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0142593
c73df57
db1d4a1
70b0288
daa925d
dc4dc34
56b7e79
5b2776b
e9b3432
7272d21
ae7c65d
11dd030
5351dc9
92582af
ddcc1e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -200,14 +200,21 @@ def _pull0(self, k): | |
| return self.view.pull(k, targets=self.targets[0], block=True) | ||
|
|
||
| def _create_local(self, local_call, shape, dist, grid_shape, dtype): | ||
| """ Creates a local array, according to the method named in `local_call`.""" | ||
| shape_name, dist_name, grid_shape_name, dtype_name = \ | ||
| self._key_and_push(shape, dist, grid_shape, dtype) | ||
| """Creates LocalArrays with the method named in `local_call`.""" | ||
| da_key = self._generate_key() | ||
| comm_key = self._comm_key | ||
| cmd = '{da_key} = {local_call}(distarray.local.maps.Distribution.from_shape({shape_name}, {dist_name}, {grid_shape_name}, {comm_key}), {dtype_name})' | ||
| comm_name = self._comm_key | ||
| distribution = Distribution.from_shape(context=self, | ||
| shape=shape, | ||
| dist=dist, | ||
| grid_shape=grid_shape) | ||
| ddpr = distribution.get_dim_data_per_rank() | ||
| ddpr_name, dtype_name = self._key_and_push(ddpr, dtype) | ||
| cmd = ('{da_key} = {local_call}(distarray.local.maps.Distribution(' | ||
| '{ddpr_name}[{comm_name}.Get_rank()], comm={comm_name}), ' | ||
| 'dtype={dtype_name})') | ||
| self._execute(cmd.format(**locals())) | ||
| return DistArray.from_localarrays(da_key, self) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution, | ||
| dtype=dtype) | ||
|
|
||
| def zeros(self, shape, dtype=float, dist=None, grid_shape=None): | ||
| if dist is None: | ||
|
|
@@ -340,7 +347,7 @@ def load_dnpy(self, name): | |
| errmsg = "`name` must be a string or a list." | ||
| raise TypeError(errmsg) | ||
|
|
||
| return DistArray.from_localarrays(da_key, self) | ||
| return DistArray.from_localarrays(da_key, context=self) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (can't comment on line 332 / 341 above) I'd like this method to also use an |
||
|
|
||
| def save_hdf5(self, filename, da, key='buffer', mode='a'): | ||
| """ | ||
|
|
@@ -409,7 +416,9 @@ def load_npy(self, filename, dim_data_per_rank): | |
| '%s = distarray.local.load_npy(%s, %s[%s.Get_rank()], %s)' % subs | ||
| ) | ||
|
|
||
| return DistArray.from_localarrays(da_key, self) | ||
| distribution = Distribution.from_dim_data_per_rank(self, | ||
| dim_data_per_rank) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
|
|
||
| def load_hdf5(self, filename, dim_data_per_rank, key='buffer'): | ||
| """ | ||
|
|
@@ -450,7 +459,10 @@ def load_hdf5(self, filename, dim_data_per_rank, key='buffer'): | |
| '%s = distarray.local.load_hdf5(%s, %s[%s.Get_rank()], %s, %s)' % subs | ||
| ) | ||
|
|
||
| return DistArray.from_localarrays(da_key, self) | ||
| distribution = Distribution.from_dim_data_per_rank(self, | ||
| dim_data_per_rank) | ||
|
|
||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
|
|
||
| def fromndarray(self, arr, dist=None, grid_shape=None): | ||
| """Convert an ndarray to a distarray.""" | ||
|
|
@@ -472,4 +484,4 @@ def fromfunction(self, function, shape, **kwargs): | |
| new_key = self._generate_key() | ||
| subs = (new_key, func_key) + keys | ||
| self._execute('%s = distarray.local.fromfunction(%s,%s,**%s)' % subs) | ||
| return DistArray.from_localarrays(new_key, self) | ||
| return DistArray.from_localarrays(new_key, context=self) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
|
|
||
|
|
||
| from distarray.client import DistArray | ||
| from distarray.client_map import Distribution | ||
|
|
||
|
|
||
| class Random(object): | ||
|
|
@@ -68,17 +69,20 @@ def rand(self, size=None, dist=None, grid_shape=None): | |
| if dist is None: | ||
| dist = {0: 'b'} | ||
| da_key = self.context._generate_key() | ||
| size_key, dist_key, grid_shape_key = \ | ||
| self.context._key_and_push(size, dist, grid_shape) | ||
| comm_key = self.context._comm_key | ||
|
|
||
| distribution = Distribution.from_shape(context=self.context, | ||
| shape=size, | ||
| dist=dist, | ||
| grid_shape=grid_shape) | ||
| ddpr = distribution.get_dim_data_per_rank() | ||
| ddpr_name = self.context._key_and_push(ddpr)[0] | ||
| comm_name = self.context._comm_key | ||
| self.context._execute( | ||
| '{da_key} = distarray.local.random.rand(' | ||
| 'distribution=distarray.local.maps.Distribution.from_shape(' | ||
| 'shape={size_key}, dist={dist_key},' | ||
| 'grid_shape={grid_shape_key}, comm={comm_key}' | ||
| '))'.format(**locals()) | ||
| ) | ||
| return DistArray.from_localarrays(da_key, self.context) | ||
| 'distribution=distarray.local.maps.Distribution(' | ||
| 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' | ||
| 'comm={comm_name}))'.format(**locals())) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to mention one more time -- translating this |
||
|
|
||
| def normal(self, loc=0.0, scale=1.0, size=None, dist=None, | ||
| grid_shape=None): | ||
|
|
@@ -140,18 +144,22 @@ def normal(self, loc=0.0, scale=1.0, size=None, dist=None, | |
| if dist is None: | ||
| dist = {0: 'b'} | ||
| da_key = self.context._generate_key() | ||
| loc_key, scale_key, size_key, dist_key, grid_shape_key = \ | ||
| self.context._key_and_push(loc, scale, size, dist, grid_shape) | ||
| comm_key = self.context._comm_key | ||
|
|
||
| distribution = Distribution.from_shape(context=self.context, | ||
| shape=size, | ||
| dist=dist, | ||
| grid_shape=grid_shape) | ||
| ddpr = distribution.get_dim_data_per_rank() | ||
| loc_name, scale_name, ddpr_name = \ | ||
| self.context._key_and_push(loc, scale, ddpr) | ||
| comm_name = self.context._comm_key | ||
| self.context._execute( | ||
| '{da_key} = distarray.local.random.normal(' | ||
| 'loc={loc_key}, scale={scale_key},' | ||
| 'distribution=distarray.local.maps.Distribution.from_shape(' | ||
| 'shape={size_key}, dist={dist_key},' | ||
| 'grid_shape={grid_shape_key}, comm={comm_key}' | ||
| '))'.format(**locals()) | ||
| ) | ||
| return DistArray.from_localarrays(da_key, self.context) | ||
| 'loc={loc_name}, scale={scale_name},' | ||
| 'distribution=distarray.local.maps.Distribution(' | ||
| 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' | ||
| 'comm={comm_name}))'.format(**locals())) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
|
|
||
| def randint(self, low, high=None, size=None, dist=None, grid_shape=None): | ||
| """ | ||
|
|
@@ -190,18 +198,22 @@ def randint(self, low, high=None, size=None, dist=None, grid_shape=None): | |
| if dist is None: | ||
| dist = {0: 'b'} | ||
| da_key = self.context._generate_key() | ||
| low_key, high_key, size_key, dist_key, grid_shape_key = \ | ||
| self.context._key_and_push(low, high, size, dist, grid_shape) | ||
| comm_key = self.context._comm_key | ||
|
|
||
| distribution = Distribution.from_shape(context=self.context, | ||
| shape=size, | ||
| dist=dist, | ||
| grid_shape=grid_shape) | ||
| ddpr = distribution.get_dim_data_per_rank() | ||
| low_name, high_name, ddpr_name = \ | ||
| self.context._key_and_push(low, high, ddpr) | ||
| comm_name = self.context._comm_key | ||
| self.context._execute( | ||
| '{da_key} = distarray.local.random.randint(' | ||
| 'low={low_key}, high={high_key},' | ||
| 'distribution=distarray.local.maps.Distribution.from_shape(' | ||
| 'shape={size_key}, dist={dist_key},' | ||
| 'grid_shape={grid_shape_key}, comm={comm_key}' | ||
| '))'.format(**locals()) | ||
| ) | ||
| return DistArray.from_localarrays(da_key, self.context) | ||
| 'low={low_name}, high={high_name},' | ||
| 'distribution=distarray.local.maps.Distribution(' | ||
| 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' | ||
| 'comm={comm_name}))'.format(**locals())) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
|
|
||
| def randn(self, size=None, dist=None, grid_shape=None): | ||
| """ | ||
|
|
@@ -229,14 +241,17 @@ def randn(self, size=None, dist=None, grid_shape=None): | |
| if dist is None: | ||
| dist = {0: 'b'} | ||
| da_key = self.context._generate_key() | ||
| size_key, dist_key, grid_shape_key = \ | ||
| self.context._key_and_push(size, dist, grid_shape) | ||
| comm_key = self.context._comm_key | ||
|
|
||
| distribution = Distribution.from_shape(context=self.context, | ||
| shape=size, | ||
| dist=dist, | ||
| grid_shape=grid_shape) | ||
| ddpr = distribution.get_dim_data_per_rank() | ||
| ddpr_name = self.context._key_and_push(ddpr)[0] | ||
| comm_name = self.context._comm_key | ||
| self.context._execute( | ||
| '{da_key} = distarray.local.random.randn(' | ||
| 'distribution=distarray.local.maps.Distribution.from_shape(' | ||
| 'shape={size_key}, dist={dist_key},' | ||
| 'grid_shape={grid_shape_key}, comm={comm_key}' | ||
| '))'.format(**locals()) | ||
| ) | ||
| return DistArray.from_localarrays(da_key, self.context) | ||
| 'distribution=distarray.local.maps.Distribution(' | ||
| 'dim_data={ddpr_name}[{comm_name}.Get_rank()], ' | ||
| 'comm={comm_name}))'.format(**locals())) | ||
| return DistArray.from_localarrays(da_key, distribution=distribution) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer it if we translated this
_execute()call into anapply_async()call. That would allow us to do without the_key_and_push()round-trip, since we could just pass those in directly.Perhaps that translation should be part of a follow-on PR.