From 7b63461f11c754c45df33785162929a66a165682 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Tue, 12 May 2026 12:09:19 +0100 Subject: [PATCH 1/3] Implement our own to_thread offload asyncio.to_thread always uses the default asyncio thread pool that contains a hardware-dependent number of threads. Although one can set the default executor on an event loop, when the loop exits, the executor is shut down. Since we want the executor thread pool to persist between collect calls we can't do that. Instead, hang an executor on the IRExecutionContext and use the new to_thread method to offload. --- python/cudf_polars/cudf_polars/dsl/ir.py | 67 +++++++++++++++++-- .../rapidsmpf/collectives/allgather.py | 9 ++- .../rapidsmpf/collectives/sort.py | 8 ++- .../experimental/rapidsmpf/frontend/core.py | 4 +- .../cudf_polars/experimental/rapidsmpf/io.py | 8 +-- .../experimental/rapidsmpf/join.py | 7 +- .../experimental/rapidsmpf/nodes.py | 2 +- .../experimental/rapidsmpf/over.py | 11 ++- .../experimental/rapidsmpf/repartition.py | 4 +- .../experimental/rapidsmpf/utils.py | 6 +- .../tests/experimental/test_allgather.py | 10 ++- 11 files changed, 101 insertions(+), 35 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index f056d1ab610..fede818b195 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -13,7 +13,10 @@ from __future__ import annotations +import asyncio import contextlib +import contextvars +import functools import itertools import json import random @@ -22,7 +25,15 @@ from dataclasses import dataclass, field from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, assert_never, overload +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + ParamSpec, + TypeVar, + assert_never, + overload, +) import polars as pl @@ -55,6 +66,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Hashable, Iterable, Sequence + from concurrent.futures import ThreadPoolExecutor from typing import Literal, Self from polars import polars # type: ignore[attr-defined] @@ -66,6 +78,9 @@ from cudf_polars.utils.config import ParquetOptions from cudf_polars.utils.timer import Timer + P = ParamSpec("P") + T = TypeVar("T") + __all__ = [ "IR", "Cache", @@ -109,9 +124,43 @@ class IRExecutionContext: A zero-argument callable that returns a CUDA stream. """ + py_executor: ThreadPoolExecutor | None = field(default=None) get_cuda_stream: Callable[[], Stream] = field(default=get_cuda_stream) query_id: uuid.UUID = field(default_factory=uuid.uuid4) + async def to_thread( + self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs + ) -> T: + """ + Run a function asynchronously in a thread. + + Parameters + ---------- + func + The function to run. + args + Arguments. + kwargs + Keyword arguments. + + Returns + ------- + Coroutine + To be awaited to obtain the result of calling ``func``. + + Notes + ----- + This offloads the function to run in the thread pool attached to + this execution context. + """ + assert self.py_executor is not None, ( + "Execution context must have a thread pool for offload" + ) + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(self.py_executor, func_call) + @contextlib.contextmanager def stream_ordered_after(self, *dfs: DataFrame) -> Generator[Stream, None, None]: """ @@ -2438,9 +2487,11 @@ def do_evaluate( right.columns, left=False, empty=True, - rename=lambda name: name - if name not in left.column_names_set - else f"{name}{suffix}", + rename=lambda name: ( + name + if name not in left.column_names_set + else f"{name}{suffix}" + ), stream=stream, ) result = DataFrame([*left_cols, *right_cols], stream=stream) @@ -2454,9 +2505,11 @@ def do_evaluate( right_cols = Join._build_columns( columns[left.num_columns :], right.columns, - rename=lambda name: name - if name not in left.column_names_set - else f"{name}{suffix}", + rename=lambda name: ( + name + if name not in left.column_names_set + else f"{name}{suffix}" + ), left=False, stream=stream, ) diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py index 163a8ba50d1..911e3f27092 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py @@ -4,7 +4,6 @@ from __future__ import annotations -import asyncio from typing import TYPE_CHECKING from rapidsmpf.integrations.cudf.partition import unpack_and_concat @@ -21,6 +20,8 @@ import pylibcudf as plc from rmm.pylibrmm.stream import Stream + from cudf_polars.dsl.ir import IRExecutionContext + class AllGatherManager: """ @@ -100,7 +101,7 @@ def inserting(self) -> AllGatherManager.Inserter: return AllGatherManager.Inserter(self) async def extract_concatenated( - self, stream: Stream, *, ordered: bool = True + self, stream: Stream, *, ordered: bool = True, ir_context: IRExecutionContext ) -> plc.Table: """ Extract the concatenated result. @@ -111,12 +112,14 @@ async def extract_concatenated( The stream to use for chunk extraction. ordered: bool Whether to extract the data in ordered or unordered fashion. + ir_context + Execution context to offload concatenation. Returns ------- The concatenated AllGather result. """ - return await asyncio.to_thread( + return await ir_context.to_thread( unpack_and_concat, partitions=await self.allgather.extract_all(self.context, ordered=ordered), stream=stream, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py index 3bcc2517c4b..6d63562a08a 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py @@ -118,7 +118,9 @@ async def _simple_top_or_bottom_k( chunk = await evaluate_chunk( context, TableChunk.from_pylibcudf_table( - await allgather.extract_concatenated(stream, ordered=True), + await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ), stream, exclusive_view=True, br=context.br(), @@ -194,7 +196,9 @@ async def _compute_sort_boundaries( allgather = AllGatherManager(context, comm, allgather_id) with allgather.inserting() as inserter: inserter.insert(comm.rank, chunk) - concat_table = await allgather.extract_concatenated(stream, ordered=True) + concat_table = await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ) return _get_final_sort_boundaries( DataFrame.from_table( concat_table, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py index ef7fbf24475..c1d5cae5088 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py @@ -439,7 +439,7 @@ def execute_ir_on_rank( Collected channel metadata. """ ir_context = IRExecutionContext( - get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id + py_executor, get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id ) metadata_collector: list[ChannelMetadata] = [] @@ -455,7 +455,7 @@ def execute_ir_on_rank( metadata_collector=metadata_collector, ) - run_actor_network(actors=nodes, py_executor=py_executor) + run_actor_network(ctx, actors=nodes) messages = output.release() chunks = [ diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py index d5005441910..f99482fb938 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py @@ -383,7 +383,7 @@ async def read_chunk( context, size=estimated_chunk_bytes, net_memory_delta=estimated_chunk_bytes ) ): - df = await asyncio.to_thread( + df = await ir_context.to_thread( scan.do_evaluate, *scan._non_child_args, context=ir_context, @@ -828,7 +828,7 @@ async def sink_node( ).make_available_and_spill(context.br(), allow_overbooking=True) df = chunk_to_frame(chunk, child_ir) part_path = f"{path_root}.{str(i).zfill(count_width)}.{suffix}" - await asyncio.to_thread( + await ir_context.to_thread( Sink.do_evaluate, ir.sink.schema, ir.sink.kind, @@ -848,7 +848,7 @@ async def sink_node( ).make_available_and_spill(context.br(), allow_overbooking=True) # Multiple chunks - use chunked writer df = chunk_to_frame(chunk, child_ir) - writer_state = await asyncio.to_thread( + writer_state = await ir_context.to_thread( _sink_to_file, ir.sink.kind, ir.sink.path, @@ -861,7 +861,7 @@ async def sink_node( if writer_state and ir.sink.kind == "Parquet": # We know that with ir.sink.kind == "Parquet", writer_state being truthy # means that it's a ChunkedParquetWriter. - await asyncio.to_thread(writer_state.close, []) # type: ignore[attr-defined] + await ir_context.to_thread(writer_state.close, []) # type: ignore[attr-defined] # Signal completion on the metadata and data channels with empty results stream = ir_context.get_cuda_stream() diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py index 6a9815678d0..6fd5280b9e4 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py @@ -4,7 +4,6 @@ from __future__ import annotations -import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal @@ -193,7 +192,7 @@ async def _collect_small_side_for_broadcast( for s_id in range(len(chunks)): inserter.insert(s_id, chunks.pop(0)) stream = ir_context.get_cuda_stream() - gathered = await allgather.extract_concatenated(stream) + gathered = await allgather.extract_concatenated(stream, ir_context=ir_context) # When every rank inserted zero chunks, the AllGather has no schema # to infer and returns a 0 column table. Substitute a properly typed # empty table for the small side so downstream joins still match the @@ -266,7 +265,7 @@ async def _broadcast_join_large_chunk( await reserve_memory(context, size=input_bytes, net_memory_delta=0) ): for sdf in dfs_to_join: - result = await asyncio.to_thread( + result = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, *([large_df, sdf] if broadcast_side == "right" else [sdf, large_df]), @@ -469,7 +468,7 @@ async def _join_chunks( with opaque_memory_usage( await reserve_memory(context, size=input_bytes, net_memory_delta=0) ): - df = await asyncio.to_thread( + df = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, chunk_to_frame(left_chunk, left), diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py index fd83c5e092f..56e6c1ed5ac 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py @@ -212,7 +212,7 @@ async def default_node_multi( for chunk, child in zip(ready_chunks, ir.children, strict=True) ] with opaque_memory_usage(extra): - df = await asyncio.to_thread( + df = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, *dfs, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py index 1a0e09b8e8e..c2c72ae9dab 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py @@ -37,7 +37,6 @@ from __future__ import annotations -import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, cast @@ -207,7 +206,7 @@ async def _evaluate_broadcast_chunk( net_memory_delta=0, ) with opaque_memory_usage(extra): - return await asyncio.to_thread( + return await ir_context.to_thread( _evaluate_ir_broadcast_sync, chunk, ir, @@ -366,7 +365,7 @@ async def _allgather_and_broadcast( net_memory_delta=0, ) with opaque_memory_usage(extra): - partial = await asyncio.to_thread( + partial = await ir_context.to_thread( _evaluate_chunk_sync, chunk, piecewise_ir, @@ -491,7 +490,7 @@ async def _distribute_by_group( # 1..nranks-1 sit idle on emit. Slice the duplicated input # across ranks (e.g. stripe by row index) and stamp each # slice with its target origin rank to distribute emit work. - stamped = await asyncio.to_thread( + stamped = await ir_context.to_thread( _append_origin_stamps, chunk, chunk_index, @@ -523,10 +522,10 @@ async def _evaluate_and_route_to_origin( partition = TableChunk.from_pylibcudf_table( extracted, stream, exclusive_view=True, br=context.br() ) - evaluated = await asyncio.to_thread( + evaluated = await ir_context.to_thread( _evaluate_window_with_stamps, partition, ir, ir_context, stamps ) - routed, splits = await asyncio.to_thread( + routed, splits = await ir_context.to_thread( _partition_by_origin_rank, evaluated, num_ranks, context.br() ) if routed is not None: diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py index c4bf2163678..dba7dd5a0a3 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py @@ -155,7 +155,9 @@ async def concatenate_node( del msg # Extract concatenated result - result_table = await allgather.extract_concatenated(stream) + result_table = await allgather.extract_concatenated( + stream, ir_context=ir_context + ) if tracer is not None: tracer.add_chunk(table=result_table) diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py index d21f6c4f977..39546950223 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py @@ -503,7 +503,7 @@ async def evaluate_chunk( ) with opaque_memory_usage(extra): for single_ir in irs: - chunk = await asyncio.to_thread( + chunk = await ir_context.to_thread( _evaluate_chunk_sync, chunk, single_ir, ir_context, context.br() ) return chunk @@ -544,7 +544,7 @@ async def allgather_and_reduce( inserter.insert(0, local_chunk) stream = ir_context.get_cuda_stream() concat_chunk = TableChunk.from_pylibcudf_table( - await allgather.extract_concatenated(stream), + await allgather.extract_concatenated(stream, ir_context=ir_context), stream, exclusive_view=True, br=context.br(), @@ -583,7 +583,7 @@ async def concat_batch( net_memory_delta=0, ) with opaque_memory_usage(extra): - df = await asyncio.to_thread( + df = await ir_context.to_thread( _concat, *[ DataFrame.from_table( diff --git a/python/cudf_polars/tests/experimental/test_allgather.py b/python/cudf_polars/tests/experimental/test_allgather.py index 3a68b3f3988..4b3a751214a 100644 --- a/python/cudf_polars/tests/experimental/test_allgather.py +++ b/python/cudf_polars/tests/experimental/test_allgather.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +from concurrent.futures import ThreadPoolExecutor from rapidsmpf.streaming.cudf.table_chunk import TableChunk @@ -13,6 +14,7 @@ import pylibcudf as plc +from cudf_polars.dsl.ir import IRExecutionContext from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from cudf_polars.experimental.rapidsmpf.utils import allgather_reduce @@ -46,8 +48,12 @@ async def _test_allgather(engine) -> None: ), ) - # Extract concatenated result - result = await allgather.extract_concatenated(stream, ordered=True) + with ThreadPoolExecutor(max_workers=1) as executor: + ir_context = IRExecutionContext(executor) + # Extract concatenated result + result = await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ) # Verify the concatenated table has the expected shape assert result.num_rows() == 600 # 100 + 200 + 300 From eefb384b99614dc8af2ee111d2a31659a953df01 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Thu, 14 May 2026 18:13:12 +0100 Subject: [PATCH 2/3] Ruff rule --- python/cudf_polars/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 91867fb16eb..a36afe7935c 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -206,6 +206,7 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-tidy-imports.banned-api] "asyncio.gather".msg = "Use gather_with_task_group instead." +"asyncio.to_thread".msg = "Use ir_context.to_thread instead." [tool.ruff.lint.flake8-type-checking] strict = true From 8aa62db19e94a430fa6b63e537c17a5f23ad352e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 15 May 2026 10:53:33 +0100 Subject: [PATCH 3/3] Fix docstrings --- python/cudf_polars/cudf_polars/dsl/ir.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index fede818b195..f109d38810e 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -120,8 +120,13 @@ class IRExecutionContext: Parameters ---------- + py_executor + Thread pool for thread offload in async execution, only used by + streaming engine. get_cuda_stream A zero-argument callable that returns a CUDA stream. + query_id + Identifier for the query being executed. """ py_executor: ThreadPoolExecutor | None = field(default=None) @@ -145,8 +150,7 @@ async def to_thread( Returns ------- - Coroutine - To be awaited to obtain the result of calling ``func``. + Awaitable to obtain the result of calling ``func``. Notes -----