diff --git a/distarray/__version__.py b/distarray/__version__.py index 651582af..25e2d375 100644 --- a/distarray/__version__.py +++ b/distarray/__version__.py @@ -8,5 +8,5 @@ Version information for the DistArray package. """ -__short_version__ = "0.5" -__version__ = "0.5.0" +__short_version__ = "0.6" +__version__ = "0.6.0-dev" diff --git a/distarray/globalapi/context.py b/distarray/globalapi/context.py index 2809e142..eb47f577 100644 --- a/distarray/globalapi/context.py +++ b/distarray/globalapi/context.py @@ -77,6 +77,12 @@ def apply(self, func, args=None, kwargs=None, targets=None): def push_function(self, key, func): pass + def __enter__(self): + return self + + def __exit__(self, type_, value, traceback): + self.close() + def _setup_context_key(self): """ Create a dict on the engines which will hold everything from @@ -205,7 +211,7 @@ def local_allclose(la, lb, rtol, atol): from numpy import allclose return allclose(la.ndarray, lb.ndarray, rtol, atol) - local_results = self.apply(local_allclose, + local_results = self.apply(local_allclose, (a.key, b.key, rtol, atol), targets=a.targets) return all(local_results) @@ -579,7 +585,7 @@ def is_NoneType(pxy): return pxy.type_str == str(type(None)) def is_LocalArray(pxy): - return (isinstance(pxy, Proxy) and + return (isinstance(pxy, Proxy) and pxy.type_str == "") if all(is_LocalArray(r) for r in results): diff --git a/distarray/globalapi/tests/test_context.py b/distarray/globalapi/tests/test_context.py index 09848eeb..bc96681d 100644 --- a/distarray/globalapi/tests/test_context.py +++ b/distarray/globalapi/tests/test_context.py @@ -9,7 +9,6 @@ Many of these tests require a 4-engine cluster to be running locally. The engines should be launched with MPI, using the MPIEngineSetLauncher. - """ import unittest @@ -19,13 +18,25 @@ from numpy.testing import assert_allclose, assert_array_equal -from distarray.testing import DefaultContextTestCase, IPythonContextTestCase, check_targets +from distarray.testing import (DefaultContextTestCase, IPythonContextTestCase, + check_targets) from distarray.globalapi.context import Context from distarray.globalapi.maps import Distribution from distarray.mpionly_utils import is_solo_mpi_process, get_nengines from distarray.localapi import LocalArray +class TestContextManager(DefaultContextTestCase): + + ntargets = 'any' + + def test_manager(self): + with Context() as mycon: + testarr = mycon.zeros((10,10)) + # `close` is currently a no-op for MPI contexts, so I don't test + # anything regarding the __exit__ behavior + + class TestRegister(DefaultContextTestCase): ntargets = 'any' @@ -53,7 +64,7 @@ def test_local_sin(self): def local_sin(da): return numpy.sin(da) self.context.register(local_sin) - + db = self.context.local_sin(self.da) assert_allclose(0, db.tondarray(), atol=1e-14) @@ -146,7 +157,7 @@ def local_none(da): self.assertTrue(dp is None) def test_parameterless(self): - + def parameterless(): """This is a parameterless function.""" return None