diff --git a/distarray/dist/maps.py b/distarray/dist/maps.py index bca5f2f6..9bcc5bb0 100644 --- a/distarray/dist/maps.py +++ b/distarray/dist/maps.py @@ -585,6 +585,12 @@ def __init__(self, context, global_dim_data, targets=None): nelts = reduce(operator.mul, self.grid_shape, 1) self.rank_from_coords = np.arange(nelts).reshape(self.grid_shape) + def __getitem__(self, idx): + return self.maps[idx] + + def __len__(self): + return len(self.maps) + @property def has_precise_index(self): """ diff --git a/distarray/dist/tests/test_maps.py b/distarray/dist/tests/test_maps.py index bc45b2e8..9eec4dda 100644 --- a/distarray/dist/tests/test_maps.py +++ b/distarray/dist/tests/test_maps.py @@ -11,6 +11,7 @@ from distarray.testing import ContextTestCase from distarray.dist import maps as client_map +from distarray.dist.maps import MapBase class TestClientMap(ContextTestCase): @@ -111,6 +112,27 @@ def test_reduce_0D(self): self.assertEqual(set(new_dist.targets), set(dist.targets[:1])) +class TestDunderMethods(ContextTestCase): + + @classmethod + def setUpClass(cls): + super(TestDunderMethods, cls).setUpClass() + cls.shape = (3, 4, 5, 6) + cls.cm = client_map.Distribution.from_shape(cls.context, cls.shape) + + def test___len__(self): + self.assertEqual(len(self.cm), 4) + + def test___getitem__(self): + for m in self.cm: + self.assertTrue(isinstance(m, MapBase)) + + self.assertEqual(self.cm[0].dist, 'b') + self.assertEqual(self.cm[1].dist, 'n') + self.assertEqual(self.cm[2].dist, 'n') + self.assertEqual(self.cm[-1].dist, 'n') + + class TestDistributionCreation(ContextTestCase): def test_all_n_dist(self): distribution = client_map.Distribution.from_shape(self.context,