Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions distarray/dist/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,19 @@ def index_owners(self, idx):
"""
raise IndexError()

def _is_compatible_degenerate(self, map):
right_types = all(isinstance(m, (NoDistMap, BlockMap, BlockCyclicMap))
for m in (self, map))
return (right_types
and self.grid_size == map.grid_size == 1
and self.size == map.size)

def is_compatible(self, map):
return ((self.dist == map.dist) and
(vars(self) == vars(map)))
if self._is_compatible_degenerate(map):
return True
else:
return ((self.dist == map.dist) and
(vars(self) == vars(map)))


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -231,6 +241,11 @@ def slice(self, idx):
return {'dist_type': self.dist,
'size': isection_size}

def is_compatible(self, other):
return (isinstance(other, (NoDistMap, BlockMap, BlockCyclicMap)) and
other.grid_size == self.grid_size and
other.size == self.size)


class BlockMap(MapBase):

Expand Down Expand Up @@ -339,6 +354,11 @@ def slice(self, idx):
return {'dist_type': self.dist,
'bounds': new_bounds}

def is_compatible(self, other):
if isinstance(other, NoDistMap):
return other.is_compatible(self)
return super(BlockMap, self).is_compatible(other)


class BlockCyclicMap(MapBase):

Expand Down Expand Up @@ -387,6 +407,11 @@ def get_dimdicts(self):
'block_size': self.block_size,
}) for grid_rank in range(self.grid_size))

def is_compatible(self, other):
if isinstance(other, NoDistMap):
return other.is_compatible(self)
return super(BlockCyclicMap, self).is_compatible(other)


class UnstructuredMap(MapBase):

Expand Down Expand Up @@ -725,8 +750,8 @@ def get_dim_data_per_rank(self):
return [dd for (_, dd) in rank_and_dd]

def is_compatible(self, o):
return ((self.context, self.targets, self.shape, self.ndim, self.dist, self.grid_shape) ==
(o.context, o.targets, o.shape, o.ndim, o.dist, o.grid_shape) and
return ((self.context, self.targets, self.shape, self.ndim, self.grid_shape) ==
(o.context, o.targets, o.shape, o.ndim, o.grid_shape) and
all(m.is_compatible(om) for (m, om) in zip(self.maps, o.maps)))

def reduce(self, axes):
Expand Down
177 changes: 138 additions & 39 deletions distarray/dist/tests/test_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
from distarray.externals.six.moves import range

from distarray.testing import ContextTestCase
from distarray.dist import maps
from distarray.dist.maps import MapBase
from distarray.dist.maps import MapBase, Distribution


class TestClientMap(ContextTestCase):

def test_2D_bn(self):
nrows, ncols = 31, 53
cm = maps.Distribution.from_shape(self.context,
(nrows, ncols),
{0: 'b'},
(4, 1))
cm = Distribution.from_shape(self.context,
(nrows, ncols),
{0: 'b'},
(4, 1))
chunksize = (nrows // 4) + 1
for _ in range(100):
r, c = randrange(nrows), randrange(ncols)
Expand All @@ -33,10 +32,10 @@ def test_2D_bn(self):
def test_2D_bb(self):
nrows, ncols = 3, 5
nprocs_per_dim = 2
cm = maps.Distribution.from_shape(self.context,
(nrows, ncols),
('b', 'b'),
(nprocs_per_dim, nprocs_per_dim))
cm = Distribution.from_shape(self.context,
(nrows, ncols),
('b', 'b'),
(nprocs_per_dim, nprocs_per_dim))
row_chunks = nrows // nprocs_per_dim + 1
col_chunks = ncols // nprocs_per_dim + 1
for r in range(nrows):
Expand All @@ -48,10 +47,10 @@ def test_2D_bb(self):
def test_2D_cc(self):
nrows, ncols = 3, 5
nprocs_per_dim = 2
cm = maps.Distribution.from_shape(self.context,
(nrows, ncols),
('c', 'c'),
(nprocs_per_dim, nprocs_per_dim))
cm = Distribution.from_shape(self.context,
(nrows, ncols),
('c', 'c'),
(nprocs_per_dim, nprocs_per_dim))
for r in range(nrows):
for c in range(ncols):
rank = ((r % nprocs_per_dim) * nprocs_per_dim
Expand All @@ -62,35 +61,135 @@ def test_2D_cc(self):
def test_is_compatible(self):
nr, nc, nd = 10**5, 10**6, 10**4

cm0 = maps.Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))
cm0 = Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))
self.assertTrue(cm0.is_compatible(cm0))

cm1 = maps.Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))
cm1 = Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))
self.assertTrue(cm1.is_compatible(cm1))

self.assertTrue(cm0.is_compatible(cm1))
self.assertTrue(cm1.is_compatible(cm0))

nr -= 1; nc -= 1; nd -= 1

cm2 = maps.Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))
cm2 = Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'))

self.assertFalse(cm1.is_compatible(cm2))
self.assertFalse(cm2.is_compatible(cm1))

def test_is_compatible_nodist(self):
# See GH issue #461.
dist_bcn = Distribution.from_shape(self.context,
(10, 10, 10),
('b', 'c', 'n'),
(1, 1, 1),
targets=[0])
dist_nnn = Distribution.from_shape(self.context,
(10, 10, 10),
('n', 'n', 'n'),
(1, 1, 1),
targets=[0])
self.assertTrue(dist_bcn.is_compatible(dist_nnn))
self.assertTrue(dist_nnn.is_compatible(dist_bcn))

def test_is_compatible_degenerate(self):
dist_bc = Distribution.from_shape(self.context,
(10, 10),
('b', 'c'),
(1, 1),
targets=[0])
dist_cb = Distribution.from_shape(self.context,
(10, 10),
('c', 'b'),
(1, 1),
targets=[0])
self.assertTrue(dist_bc.is_compatible(dist_cb))
self.assertTrue(dist_cb.is_compatible(dist_bc))

def test_is_compatible_degenerate_block_cyclic(self):
size = 19937
gdd_block_cyclic = (
{
'dist_type': 'c',
'proc_grid_size': 1,
'block_size': 7,
'size': size,
},
)
gdd_block = (
{
'dist_type': 'b',
'proc_grid_size': 1,
'bounds': [0, size],
},
)
gdd_cyclic = (
{
'dist_type': 'c',
'proc_grid_size': 1,
'size': size,
},
)
dist_block_cyclic = Distribution(self.context, gdd_block_cyclic)
dist_block = Distribution(self.context, gdd_block)
dist_cyclic = Distribution(self.context, gdd_cyclic)

self.assertTrue(dist_block_cyclic.is_compatible(dist_block))
self.assertTrue(dist_block_cyclic.is_compatible(dist_cyclic))

self.assertTrue(dist_block.is_compatible(dist_block_cyclic))
self.assertTrue(dist_cyclic.is_compatible(dist_block_cyclic))

def test_not_compatible(self):
dist_b1 = Distribution.from_shape(self.context,
(10,), ('b',),
(1,), targets=[0])

dist_b2 = Distribution.from_shape(self.context,
(9,), ('b',),
(1,), targets=[0])

self.assertFalse(dist_b1.is_compatible(dist_b2))
self.assertFalse(dist_b2.is_compatible(dist_b1))

dist_b3 = Distribution.from_shape(self.context,
(10,), ('b',),
(2,), targets=[0,1])

self.assertFalse(dist_b1.is_compatible(dist_b3))
self.assertFalse(dist_b3.is_compatible(dist_b1))

dist_b4 = Distribution.from_shape(self.context,
(10,), ('c',),
(2,), targets=[0,1])

self.assertFalse(dist_b4.is_compatible(dist_b3))
self.assertFalse(dist_b3.is_compatible(dist_b4))

gdd_unstructured = (
{
'dist_type': 'u',
'indices': [range(10)],
},
)
dist_u = Distribution(self.context, gdd_unstructured)

self.assertFalse(dist_u.is_compatible(dist_b1))
self.assertFalse(dist_b1.is_compatible(dist_u))

def test_reduce(self):
nr, nc, nd = 10**5, 10**6, 10**4

dist = maps.Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'),
grid_shape=(2, 2, 1))
dist = Distribution.from_shape(self.context,
(nr, nc, nd),
('b', 'c', 'n'),
grid_shape=(2, 2, 1))

new_dist0 = dist.reduce(axes=[0])
self.assertEqual(new_dist0.dist, ('c', 'n'))
Expand All @@ -113,7 +212,7 @@ def test_reduce(self):

def test_reduce_0D(self):
N = 10**5
dist = maps.Distribution.from_shape(self.context, (N,))
dist = Distribution.from_shape(self.context, (N,))
new_dist = dist.reduce(axes=[0])
self.assertEqual(new_dist.dist, ())
self.assertSequenceEqual(new_dist.shape, ())
Expand All @@ -124,7 +223,7 @@ def test_reduce_0D(self):
class TestSlice(ContextTestCase):

def test_from_partial_slice_1d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15,))
d0 = Distribution.from_shape(context=self.context, shape=(15,))

s = (slice(0, 3),)
d1 = d0.slice(s)
Expand All @@ -135,7 +234,7 @@ def test_from_partial_slice_1d(self):
self.assertSequenceEqual(d1.shape, (3,))

def test_from_full_slice_1d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15,))
d0 = Distribution.from_shape(context=self.context, shape=(15,))

s = (slice(None),)
d1 = d0.slice(s)
Expand All @@ -146,7 +245,7 @@ def test_from_full_slice_1d(self):
self.assertSequenceEqual(d1.maps[0].bounds, d0.maps[0].bounds)

def test_from_full_slice_with_step_1d_0(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15,))
d0 = Distribution.from_shape(context=self.context, shape=(15,))

s = (slice(None, None, 2),)
d1 = d0.slice(s)
Expand All @@ -157,7 +256,7 @@ def test_from_full_slice_with_step_1d_0(self):
self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0])

def test_from_full_slice_with_step_1d_1(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(30,))
d0 = Distribution.from_shape(context=self.context, shape=(30,))
step = 4

s = (slice(4, None, step),)
Expand All @@ -169,7 +268,7 @@ def test_from_full_slice_with_step_1d_1(self):
self.assertEqual(d1.maps[0].bounds[0][0], d0.maps[0].bounds[0][0])

def test_from_full_slice_2d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20))
d0 = Distribution.from_shape(context=self.context, shape=(15, 20))

s = (slice(None), slice(None))
d1 = d0.slice(s)
Expand All @@ -182,7 +281,7 @@ def test_from_full_slice_2d(self):
self.assertSequenceEqual(d1.targets, d0.targets)

def test_from_partial_slice_2d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20))
d0 = Distribution.from_shape(context=self.context, shape=(15, 20))

s = (slice(3, 7), 4)
d1 = d0.slice(s)
Expand All @@ -193,7 +292,7 @@ def test_from_partial_slice_2d(self):
self.assertSequenceEqual(m.bounds, expected)

def test_full_slice_with_int_2d(self):
d0 = maps.Distribution.from_shape(context=self.context, shape=(15, 20))
d0 = Distribution.from_shape(context=self.context, shape=(15, 20))

s = (slice(None), 4)
d1 = d0.slice(s)
Expand All @@ -209,7 +308,7 @@ class TestDunderMethods(ContextTestCase):
def setUpClass(cls):
super(TestDunderMethods, cls).setUpClass()
cls.shape = (3, 4, 5, 6)
cls.cm = maps.Distribution.from_shape(cls.context, cls.shape)
cls.cm = Distribution.from_shape(cls.context, cls.shape)

def test___len__(self):
self.assertEqual(len(self.cm), 4)
Expand All @@ -226,9 +325,9 @@ def test___getitem__(self):

class TestDistributionCreation(ContextTestCase):
def test_all_n_dist(self):
distribution = maps.Distribution.from_shape(self.context,
shape=(3, 3),
dist=('n', 'n'))
distribution = Distribution.from_shape(self.context,
shape=(3, 3),
dist=('n', 'n'))
self.context.ones(distribution)


Expand Down