From 9f2763b7f70d0d48f4297a433158b037ed040177 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 15 May 2026 17:57:54 +0000 Subject: [PATCH] Define AllGather.__enter/exit__ for insert_finished --- python/rapidsmpf/rapidsmpf/coll/allgather.pyi | 9 +++++++++ python/rapidsmpf/rapidsmpf/coll/allgather.pyx | 7 +++++++ .../rapidsmpf/streaming/coll/allgather.pyi | 10 ++++++++++ .../rapidsmpf/streaming/coll/allgather.pyx | 7 +++++++ .../rapidsmpf/tests/streaming/test_allgather.py | 8 ++++---- .../rapidsmpf/rapidsmpf/tests/test_allgather.py | 16 +++++++--------- 6 files changed, 44 insertions(+), 13 deletions(-) diff --git a/python/rapidsmpf/rapidsmpf/coll/allgather.pyi b/python/rapidsmpf/rapidsmpf/coll/allgather.pyi index e89d6858a..2cbf89450 100644 --- a/python/rapidsmpf/rapidsmpf/coll/allgather.pyi +++ b/python/rapidsmpf/rapidsmpf/coll/allgather.pyi @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from typing import Any + from rapidsmpf.communicator.communicator import Communicator from rapidsmpf.memory.buffer_resource import BufferResource from rapidsmpf.memory.packed_data import PackedData @@ -17,6 +19,13 @@ class AllGather: ) -> None: ... @property def comm(self) -> Communicator: ... + def __enter__(self) -> AllGather: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: Any | None, + ) -> bool: ... def insert(self, sequence_number: int, packed_data: PackedData) -> None: ... def insert_finished(self) -> None: ... def wait_and_extract( diff --git a/python/rapidsmpf/rapidsmpf/coll/allgather.pyx b/python/rapidsmpf/rapidsmpf/coll/allgather.pyx index 01d345897..a55913d23 100644 --- a/python/rapidsmpf/rapidsmpf/coll/allgather.pyx +++ b/python/rapidsmpf/rapidsmpf/coll/allgather.pyx @@ -116,6 +116,13 @@ cdef class AllGather: with nogil: deref(self._handle).insert_finished() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.insert_finished() + return False # do not suppress exceptions + def wait_and_extract(self, bool ordered = True, int timeout_ms = -1): """ Wait for completion and extract all gathered data. diff --git a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi index 2fae28c60..cb870baa5 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi +++ b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyi @@ -1,5 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 + +from typing import Any + from rapidsmpf.communicator.communicator import Communicator from rapidsmpf.memory.packed_data import PackedData from rapidsmpf.streaming.chunks.packed_data import PackedDataChunk @@ -11,6 +14,13 @@ class AllGather: def __init__(self, ctx: Context, comm: Communicator, op_id: int) -> None: ... @property def comm(self) -> Communicator: ... + def __enter__(self) -> AllGather: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: Any | None, + ) -> bool: ... def insert(self, sequence_number: int, packed_data: PackedData) -> None: ... def insert_finished(self) -> None: ... async def extract_all(self, ctx: Context, *, ordered: bool) -> list[PackedData]: ... diff --git a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx index 25fd0aa71..e2988d3c5 100644 --- a/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx +++ b/python/rapidsmpf/rapidsmpf/streaming/coll/allgather.pyx @@ -124,6 +124,13 @@ cdef class AllGather: with nogil: deref(self._handle).insert_finished() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.insert_finished() + return False # do not suppress exceptions + async def extract_all(self, Context ctx, *, bool ordered): """ Suspend and extract all data from the AllGather. diff --git a/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py b/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py index 81c021f81..085510412 100644 --- a/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py +++ b/python/rapidsmpf/rapidsmpf/tests/streaming/test_allgather.py @@ -125,10 +125,10 @@ async def allgather_and_concat( op_id: int, ) -> None: gather = AllGather(context, comm, op_id) - while (msg := await ch_in.recv(context)) is not None: - chunk = PackedDataChunk.from_message(msg, br=context.br()).to_packed_data() - gather.insert(msg.sequence_number, chunk) - gather.insert_finished() + with gather as ag: + while (msg := await ch_in.recv(context)) is not None: + chunk = PackedDataChunk.from_message(msg, br=context.br()).to_packed_data() + ag.insert(msg.sequence_number, chunk) gathered = await gather.extract_all(context, ordered=True) stream = context.get_stream_from_pool() table = unpack_and_concat(gathered, stream, context.br()) diff --git a/python/rapidsmpf/rapidsmpf/tests/test_allgather.py b/python/rapidsmpf/rapidsmpf/tests/test_allgather.py index 7493b8509..ab6f06558 100644 --- a/python/rapidsmpf/rapidsmpf/tests/test_allgather.py +++ b/python/rapidsmpf/rapidsmpf/tests/test_allgather.py @@ -145,15 +145,13 @@ def test_basic_allgather( this_rank = comm.rank n_ranks = comm.nranks - # Insert data from this rank - for i in range(n_inserts): - packed_data = generate_packed_data( - n_elements, gen_offset(i, this_rank), stream, br - ) - allgather.insert(i, packed_data) - - # Mark this rank as finished - allgather.insert_finished() + # Insert data from this rank and mark as finished + with allgather as ag: + for i in range(n_inserts): + packed_data = generate_packed_data( + n_elements, gen_offset(i, this_rank), stream, br + ) + ag.insert(i, packed_data) # Wait for completion and extract results results = allgather.wait_and_extract(ordered=ordered)