diff --git a/python/rapidsmpf/rapidsmpf/coll/allgather.pxd b/python/rapidsmpf/rapidsmpf/coll/allgather.pxd
index 09ba8bf2c..59ffa4c05 100644
--- a/python/rapidsmpf/rapidsmpf/coll/allgather.pxd
+++ b/python/rapidsmpf/rapidsmpf/coll/allgather.pxd
@@ -47,3 +47,4 @@ cdef class AllGather:
     cdef unique_ptr[cpp_AllGather] _handle
     cdef BufferResource _br
     cdef Communicator _comm
+    cdef bint _in_context  # True between __enter__ and __exit__
diff --git a/python/rapidsmpf/rapidsmpf/coll/allgather.pyi b/python/rapidsmpf/rapidsmpf/coll/allgather.pyi
index e89d6858a..2cbf89450 100644
--- a/python/rapidsmpf/rapidsmpf/coll/allgather.pyi
+++ b/python/rapidsmpf/rapidsmpf/coll/allgather.pyi
@@ -2,6 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+from types import TracebackType
+from typing import Literal
+
 from rapidsmpf.communicator.communicator import Communicator
 from rapidsmpf.memory.buffer_resource import BufferResource
 from rapidsmpf.memory.packed_data import PackedData
@@ -17,6 +20,13 @@ class AllGather:
     ) -> None: ...
     @property
     def comm(self) -> Communicator: ...
+    def __enter__(self) -> AllGather: ...
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> Literal[False]: ...
     def insert(self, sequence_number: int, packed_data: PackedData) -> None: ...
     def insert_finished(self) -> None: ...
     def wait_and_extract(
diff --git a/python/rapidsmpf/rapidsmpf/coll/allgather.pyx b/python/rapidsmpf/rapidsmpf/coll/allgather.pyx
index 01d345897..8a206d314 100644
--- a/python/rapidsmpf/rapidsmpf/coll/allgather.pyx
+++ b/python/rapidsmpf/rapidsmpf/coll/allgather.pyx
@@ -59,6 +59,7 @@ cdef class AllGather:
         cdef cpp_BufferResource* br_ = br.ptr()
         if statistics is None:
             statistics = Statistics(enable=False)  # Disables statistics.
+        self._in_context = False
         with nogil:
             self._handle = make_unique[cpp_AllGather](
                 comm._handle,
@@ -113,9 +114,43 @@ cdef class AllGather:
         This method signals that no more data will be inserted by this rank.
         All ranks must call this method for the allgather operation to complete.
+
+        Raises
+        ------
+        ValueError
+            If called inside a ``with`` block on this instance; leaving the
+            block calls this method automatically.
         """
+        if self._in_context:
+            raise ValueError("Cannot call insert_finished() from within a context")
         with nogil:
             deref(self._handle).insert_finished()
 
+    def __enter__(self):
+        """
+        Enter a context that marks this rank as finished on exit.
+
+        While the context is active, ``insert_finished()`` must not be called
+        directly; it is called automatically when the ``with`` block exits.
+
+        Returns
+        -------
+        This AllGather instance.
+        """
+        self._in_context = True
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """
+        Leave the context and call ``insert_finished()``.
+
+        ``insert_finished()`` is called whether or not the ``with`` block
+        raised. Returns ``False`` so that an exception raised in the block
+        propagates to the caller.
+        """
+        self._in_context = False
+        self.insert_finished()
+        return False  # do not suppress exceptions
+
     def wait_and_extract(self, bool ordered = True, int timeout_ms = -1):
         """
         Wait for completion and extract all gathered data.
diff --git a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pxd b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pxd
index 5ae0ed27c..3d964d681 100644
--- a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pxd
+++ b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pxd
@@ -36,3 +36,4 @@ cdef extern from "" nogil:
 cdef class AllGather:
     cdef unique_ptr[cpp_AllGather] _handle
     cdef Communicator _comm
+    cdef bint _in_context  # True between __enter__ and __exit__
diff --git a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi
index 2fae28c60..cb870baa5 100644
--- a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi
+++ b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi
@@ -1,5 +1,9 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
+
+from types import TracebackType
+from typing import Literal
+
 from rapidsmpf.communicator.communicator import Communicator
 from rapidsmpf.memory.packed_data import PackedData
 from rapidsmpf.streaming.chunks.packed_data import PackedDataChunk
@@ -11,6 +15,13 @@ class AllGather:
     def __init__(self, ctx: Context, comm: Communicator, op_id: int) -> None: ...
     @property
     def comm(self) -> Communicator: ...
+    def __enter__(self) -> AllGather: ...
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> Literal[False]: ...
     def insert(self, sequence_number: int, packed_data: PackedData) -> None: ...
     def insert_finished(self) -> None: ...
     async def extract_all(self, ctx: Context, *, ordered: bool) -> list[PackedData]: ...
diff --git a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx
index 25fd0aa71..12ab0ea27 100644
--- a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx
+++ b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx
@@ -80,6 +80,7 @@ cdef class AllGather:
     """
     def __init__(self, Context ctx not None, Communicator comm not None, int32_t op_id):
         self._comm = comm
+        self._in_context = False
         with nogil:
             self._handle = make_unique[cpp_AllGather](
                 ctx._handle, comm._handle, op_id
@@ -121,9 +122,43 @@ cdef class AllGather:
         """
         Insert a finished marker into the AllGather.
+
+        Raises
+        ------
+        ValueError
+            If called inside a ``with`` block on this instance; leaving the
+            block calls this method automatically.
         """
+        if self._in_context:
+            raise ValueError("Cannot call insert_finished() from within a context")
         with nogil:
             deref(self._handle).insert_finished()
 
+    def __enter__(self):
+        """
+        Enter a context that inserts the finished marker on exit.
+
+        While the context is active, ``insert_finished()`` must not be called
+        directly; it is called automatically when the ``with`` block exits.
+
+        Returns
+        -------
+        This AllGather instance.
+        """
+        self._in_context = True
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """
+        Leave the context and call ``insert_finished()``.
+
+        ``insert_finished()`` is called whether or not the ``with`` block
+        raised. Returns ``False`` so that an exception raised in the block
+        propagates to the caller.
+        """
+        self._in_context = False
+        self.insert_finished()
+        return False  # do not suppress exceptions
+
     async def extract_all(self, Context ctx, *, bool ordered):
         """
         Suspend and extract all data from the AllGather.
diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py
index afeca966d..884e21b41 100644
--- a/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py
+++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -122,12 +123,16 @@ async def allgather_and_concat(
     ch_in: Channel[PackedDataChunk],
     ch_out: Channel[TableChunk],
     op_id: int,
+    use_context_manager: bool,  # noqa: FBT001
 ) -> None:
     gather = AllGather(context, comm, op_id)
-    while (msg := await ch_in.recv(context)) is not None:
-        chunk = PackedDataChunk.from_message(msg, br=context.br()).to_packed_data()
-        gather.insert(msg.sequence_number, chunk)
-    gather.insert_finished()
+    cm = gather if use_context_manager else nullcontext(gather)
+    with cm as ag:
+        while (msg := await ch_in.recv(context)) is not None:
+            chunk = PackedDataChunk.from_message(msg, br=context.br()).to_packed_data()
+            ag.insert(msg.sequence_number, chunk)
+    if not use_context_manager:
+        gather.insert_finished()
     gathered = await gather.extract_all(context, ordered=True)
     stream = context.get_stream_from_pool()
     table = unpack_and_concat(gathered, stream, context.br())
@@ -138,7 +143,14 @@ async def allgather_and_concat(
     await ch_out.drain(context)
 
 
-def test_allgather_object_interface(context: Context, comm: Communicator) -> None:
+@pytest.mark.parametrize(
+    "use_context_manager", [True, False], ids=["context", "non-context"]
+)
+def test_allgather_object_interface(
+    context: Context,
+    comm: Communicator,
+    use_context_manager: bool,  # noqa: FBT001
+) -> None:
     if comm.nranks != 1:
         pytest.skip("Only support single-rank runs")
 
@@ -149,7 +161,9 @@ def test_allgather_object_interface(context: Context, comm: Communicator) -> Non
     num_chunks = 10
     op_id = 0
     actors.append(generate_inputs(context, ch_in, num_rows, num_chunks))
-    actors.append(allgather_and_concat(context, comm, ch_in, ch_out, op_id))
+    actors.append(
+        allgather_and_concat(context, comm, ch_in, ch_out, op_id, use_context_manager)
+    )
     actor, deferred = pull_from_channel(context, ch_out)
     actors.append(actor)
 
diff --git a/python/rapidsmpf/rapidsmpf/tests/test_allgather.py b/python/rapidsmpf/rapidsmpf/tests/test_allgather.py
index 6cf867708..19861883d 100644
--- a/python/rapidsmpf/rapidsmpf/tests/test_allgather.py
+++ b/python/rapidsmpf/rapidsmpf/tests/test_allgather.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+from contextlib import nullcontext
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -116,6 +117,9 @@ def gen_offset(i: int, r: int) -> int:
 @pytest.mark.parametrize("n_elements", [0, 1, 10, 100])
 @pytest.mark.parametrize("n_inserts", [0, 1, 10])
 @pytest.mark.parametrize("ordered", [False, True])
+@pytest.mark.parametrize(
+    "use_context_manager", [True, False], ids=["context", "non-context"]
+)
 def test_basic_allgather(
     comm: Communicator,
     device_mr: rmm.mr.CudaMemoryResource,
@@ -123,6 +127,7 @@ def test_basic_allgather(
     n_elements: int,
     n_inserts: int,
     ordered: bool,  # noqa: FBT001
+    use_context_manager: bool,  # noqa: FBT001
 ) -> None:
     """
     Test basic AllGather functionality.
@@ -145,15 +150,18 @@ def test_basic_allgather(
     this_rank = comm.rank
     n_ranks = comm.nranks
 
-    # Insert data from this rank
-    for i in range(n_inserts):
-        packed_data = generate_packed_data(
-            n_elements, gen_offset(i, this_rank), stream, br
-        )
-        allgather.insert(i, packed_data)
-
-    # Mark this rank as finished
-    allgather.insert_finished()
+    # Insert data from this rank. Used as a context manager, the AllGather
+    # marks this rank as finished when the block exits.
+    cm = allgather if use_context_manager else nullcontext(allgather)
+    with cm as ag:
+        for i in range(n_inserts):
+            packed_data = generate_packed_data(
+                n_elements, gen_offset(i, this_rank), stream, br
+            )
+            ag.insert(i, packed_data)
+    if not use_context_manager:
+        # Mark this rank as finished
+        allgather.insert_finished()
 
     # Wait for completion and extract results
     results = allgather.wait_and_extract(ordered=ordered)
@@ -188,3 +196,23 @@ def test_basic_allgather(
             # This will at least verify the structure and size
             result_table = unpack_and_concat([result], stream, br)
             assert result_table.num_rows() == n_elements
+
+
+def test_insert_finished_raises_in_context(
+    comm: Communicator,
+    device_mr: rmm.mr.CudaMemoryResource,
+) -> None:
+    """Test that insert_finished raises when called inside a context manager."""
+    br = BufferResource(device_mr)
+    ag = AllGather(comm=comm, op_id=0, br=br)
+    # Managers exit in reverse order: pytest.raises consumes the ValueError
+    # first, then ag.__exit__ runs and calls insert_finished() itself.
+    with (
+        ag,
+        pytest.raises(
+            ValueError, match=r"Cannot call insert_finished.*within a context"
+        ),
+    ):
+        ag.insert_finished()
+    # ag.__exit__ already called insert_finished(); now wait for the result.
+    ag.wait_and_extract(ordered=True)