diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index f056d1ab610..f109d38810e 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -13,7 +13,10 @@ from __future__ import annotations +import asyncio import contextlib +import contextvars +import functools import itertools import json import random @@ -22,7 +25,15 @@ from dataclasses import dataclass, field from functools import cache from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, assert_never, overload +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + ParamSpec, + TypeVar, + assert_never, + overload, +) import polars as pl @@ -55,6 +66,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Hashable, Iterable, Sequence + from concurrent.futures import ThreadPoolExecutor from typing import Literal, Self from polars import polars # type: ignore[attr-defined] @@ -66,6 +78,9 @@ from cudf_polars.utils.config import ParquetOptions from cudf_polars.utils.timer import Timer + P = ParamSpec("P") + T = TypeVar("T") + __all__ = [ "IR", "Cache", @@ -105,13 +120,51 @@ class IRExecutionContext: Parameters ---------- + py_executor + Thread pool for thread offload in async execution, only used by + streaming engine. get_cuda_stream A zero-argument callable that returns a CUDA stream. + query_id + Identifier for the query being executed. """ + py_executor: ThreadPoolExecutor | None = field(default=None) get_cuda_stream: Callable[[], Stream] = field(default=get_cuda_stream) query_id: uuid.UUID = field(default_factory=uuid.uuid4) + async def to_thread( + self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs + ) -> T: + """ + Run a function asynchronously in a thread. + + Parameters + ---------- + func + The function to run. + args + Arguments. + kwargs + Keyword arguments. + + Returns + ------- + Awaitable to obtain the result of calling ``func``. + + Notes + ----- + This offloads the function to run in the thread pool attached to + this execution context. + """ + assert self.py_executor is not None, ( + "Execution context must have a thread pool for offload" + ) + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(self.py_executor, func_call) + @contextlib.contextmanager def stream_ordered_after(self, *dfs: DataFrame) -> Generator[Stream, None, None]: """ @@ -2438,9 +2491,11 @@ def do_evaluate( right.columns, left=False, empty=True, - rename=lambda name: name - if name not in left.column_names_set - else f"{name}{suffix}", + rename=lambda name: ( + name + if name not in left.column_names_set + else f"{name}{suffix}" + ), stream=stream, ) result = DataFrame([*left_cols, *right_cols], stream=stream) @@ -2454,9 +2509,11 @@ def do_evaluate( right_cols = Join._build_columns( columns[left.num_columns :], right.columns, - rename=lambda name: name - if name not in left.column_names_set - else f"{name}{suffix}", + rename=lambda name: ( + name + if name not in left.column_names_set + else f"{name}{suffix}" + ), left=False, stream=stream, ) diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py index 163a8ba50d1..911e3f27092 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/allgather.py @@ -4,7 +4,6 @@ from __future__ import annotations -import asyncio from typing import TYPE_CHECKING from rapidsmpf.integrations.cudf.partition import unpack_and_concat @@ -21,6 +20,8 @@ import pylibcudf as plc from rmm.pylibrmm.stream import Stream + from cudf_polars.dsl.ir import IRExecutionContext + class AllGatherManager: """ @@ -100,7 +101,7 @@ def inserting(self) -> AllGatherManager.Inserter: return AllGatherManager.Inserter(self) async def extract_concatenated( - self, stream: Stream, *, ordered: bool = True + self, stream: Stream, *, ordered: bool = True, ir_context: IRExecutionContext ) -> plc.Table: """ Extract the concatenated result. @@ -111,12 +112,14 @@ async def extract_concatenated( The stream to use for chunk extraction. ordered: bool Whether to extract the data in ordered or unordered fashion. + ir_context + Execution context to offload concatenation. Returns ------- The concatenated AllGather result. """ - return await asyncio.to_thread( + return await ir_context.to_thread( unpack_and_concat, partitions=await self.allgather.extract_all(self.context, ordered=ordered), stream=stream, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py index 3bcc2517c4b..6d63562a08a 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py @@ -118,7 +118,9 @@ async def _simple_top_or_bottom_k( chunk = await evaluate_chunk( context, TableChunk.from_pylibcudf_table( - await allgather.extract_concatenated(stream, ordered=True), + await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ), stream, exclusive_view=True, br=context.br(), @@ -194,7 +196,9 @@ async def _compute_sort_boundaries( allgather = AllGatherManager(context, comm, allgather_id) with allgather.inserting() as inserter: inserter.insert(comm.rank, chunk) - concat_table = await allgather.extract_concatenated(stream, ordered=True) + concat_table = await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ) return _get_final_sort_boundaries( DataFrame.from_table( concat_table, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py index ef7fbf24475..c1d5cae5088 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/core.py @@ -439,7 +439,7 @@ def execute_ir_on_rank( Collected channel metadata. """ ir_context = IRExecutionContext( - get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id + py_executor, get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id ) metadata_collector: list[ChannelMetadata] = [] @@ -455,7 +455,7 @@ def execute_ir_on_rank( metadata_collector=metadata_collector, ) - run_actor_network(actors=nodes, py_executor=py_executor) + run_actor_network(ctx, actors=nodes) messages = output.release() chunks = [ diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py index d5005441910..f99482fb938 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py @@ -383,7 +383,7 @@ async def read_chunk( context, size=estimated_chunk_bytes, net_memory_delta=estimated_chunk_bytes ) ): - df = await asyncio.to_thread( + df = await ir_context.to_thread( scan.do_evaluate, *scan._non_child_args, context=ir_context, @@ -828,7 +828,7 @@ async def sink_node( ).make_available_and_spill(context.br(), allow_overbooking=True) df = chunk_to_frame(chunk, child_ir) part_path = f"{path_root}.{str(i).zfill(count_width)}.{suffix}" - await asyncio.to_thread( + await ir_context.to_thread( Sink.do_evaluate, ir.sink.schema, ir.sink.kind, @@ -848,7 +848,7 @@ async def sink_node( ).make_available_and_spill(context.br(), allow_overbooking=True) # Multiple chunks - use chunked writer df = chunk_to_frame(chunk, child_ir) - writer_state = await asyncio.to_thread( + writer_state = await ir_context.to_thread( _sink_to_file, ir.sink.kind, ir.sink.path, @@ -861,7 +861,7 @@ async def sink_node( if writer_state and ir.sink.kind == "Parquet": # We know that with ir.sink.kind == "Parquet", writer_state being truthy # means that it's a ChunkedParquetWriter. - await asyncio.to_thread(writer_state.close, []) # type: ignore[attr-defined] + await ir_context.to_thread(writer_state.close, []) # type: ignore[attr-defined] # Signal completion on the metadata and data channels with empty results stream = ir_context.get_cuda_stream() diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py index 6a9815678d0..6fd5280b9e4 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py @@ -4,7 +4,6 @@ from __future__ import annotations -import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal @@ -193,7 +192,7 @@ async def _collect_small_side_for_broadcast( for s_id in range(len(chunks)): inserter.insert(s_id, chunks.pop(0)) stream = ir_context.get_cuda_stream() - gathered = await allgather.extract_concatenated(stream) + gathered = await allgather.extract_concatenated(stream, ir_context=ir_context) # When every rank inserted zero chunks, the AllGather has no schema # to infer and returns a 0 column table. Substitute a properly typed # empty table for the small side so downstream joins still match the @@ -266,7 +265,7 @@ async def _broadcast_join_large_chunk( await reserve_memory(context, size=input_bytes, net_memory_delta=0) ): for sdf in dfs_to_join: - result = await asyncio.to_thread( + result = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, *([large_df, sdf] if broadcast_side == "right" else [sdf, large_df]), @@ -469,7 +468,7 @@ async def _join_chunks( with opaque_memory_usage( await reserve_memory(context, size=input_bytes, net_memory_delta=0) ): - df = await asyncio.to_thread( + df = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, chunk_to_frame(left_chunk, left), diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py index fd83c5e092f..56e6c1ed5ac 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py @@ -212,7 +212,7 @@ async def default_node_multi( for chunk, child in zip(ready_chunks, ir.children, strict=True) ] with opaque_memory_usage(extra): - df = await asyncio.to_thread( + df = await ir_context.to_thread( ir.do_evaluate, *ir._non_child_args, *dfs, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py index 1a0e09b8e8e..c2c72ae9dab 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py @@ -37,7 +37,6 @@ from __future__ import annotations -import asyncio from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, cast @@ -207,7 +206,7 @@ async def _evaluate_broadcast_chunk( net_memory_delta=0, ) with opaque_memory_usage(extra): - return await asyncio.to_thread( + return await ir_context.to_thread( _evaluate_ir_broadcast_sync, chunk, ir, @@ -366,7 +365,7 @@ async def _allgather_and_broadcast( net_memory_delta=0, ) with opaque_memory_usage(extra): - partial = await asyncio.to_thread( + partial = await ir_context.to_thread( _evaluate_chunk_sync, chunk, piecewise_ir, @@ -491,7 +490,7 @@ async def _distribute_by_group( # 1..nranks-1 sit idle on emit. Slice the duplicated input # across ranks (e.g. stripe by row index) and stamp each # slice with its target origin rank to distribute emit work. - stamped = await asyncio.to_thread( + stamped = await ir_context.to_thread( _append_origin_stamps, chunk, chunk_index, @@ -523,10 +522,10 @@ async def _evaluate_and_route_to_origin( partition = TableChunk.from_pylibcudf_table( extracted, stream, exclusive_view=True, br=context.br() ) - evaluated = await asyncio.to_thread( + evaluated = await ir_context.to_thread( _evaluate_window_with_stamps, partition, ir, ir_context, stamps ) - routed, splits = await asyncio.to_thread( + routed, splits = await ir_context.to_thread( _partition_by_origin_rank, evaluated, num_ranks, context.br() ) if routed is not None: diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py index c4bf2163678..dba7dd5a0a3 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py @@ -155,7 +155,9 @@ async def concatenate_node( del msg # Extract concatenated result - result_table = await allgather.extract_concatenated(stream) + result_table = await allgather.extract_concatenated( + stream, ir_context=ir_context + ) if tracer is not None: tracer.add_chunk(table=result_table) diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py index d21f6c4f977..39546950223 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py @@ -503,7 +503,7 @@ async def evaluate_chunk( ) with opaque_memory_usage(extra): for single_ir in irs: - chunk = await asyncio.to_thread( + chunk = await ir_context.to_thread( _evaluate_chunk_sync, chunk, single_ir, ir_context, context.br() ) return chunk @@ -544,7 +544,7 @@ async def allgather_and_reduce( inserter.insert(0, local_chunk) stream = ir_context.get_cuda_stream() concat_chunk = TableChunk.from_pylibcudf_table( - await allgather.extract_concatenated(stream), + await allgather.extract_concatenated(stream, ir_context=ir_context), stream, exclusive_view=True, br=context.br(), @@ -583,7 +583,7 @@ async def concat_batch( net_memory_delta=0, ) with opaque_memory_usage(extra): - df = await asyncio.to_thread( + df = await ir_context.to_thread( _concat, *[ DataFrame.from_table( diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 91867fb16eb..a36afe7935c 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -206,6 +206,7 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-tidy-imports.banned-api] "asyncio.gather".msg = "Use gather_with_task_group instead." +"asyncio.to_thread".msg = "Use ir_context.to_thread instead." [tool.ruff.lint.flake8-type-checking] strict = true diff --git a/python/cudf_polars/tests/experimental/test_allgather.py b/python/cudf_polars/tests/experimental/test_allgather.py index 3a68b3f3988..4b3a751214a 100644 --- a/python/cudf_polars/tests/experimental/test_allgather.py +++ b/python/cudf_polars/tests/experimental/test_allgather.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio +from concurrent.futures import ThreadPoolExecutor from rapidsmpf.streaming.cudf.table_chunk import TableChunk @@ -13,6 +14,7 @@ import pylibcudf as plc +from cudf_polars.dsl.ir import IRExecutionContext from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from cudf_polars.experimental.rapidsmpf.utils import allgather_reduce @@ -46,8 +48,12 @@ async def _test_allgather(engine) -> None: ), ) - # Extract concatenated result - result = await allgather.extract_concatenated(stream, ordered=True) + with ThreadPoolExecutor(max_workers=1) as executor: + ir_context = IRExecutionContext(executor) + # Extract concatenated result + result = await allgather.extract_concatenated( + stream, ordered=True, ir_context=ir_context + ) # Verify the concatenated table has the expected shape assert result.num_rows() == 600 # 100 + 200 + 300