Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions python/cudf_polars/cudf_polars/dsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

from __future__ import annotations

import asyncio
import contextlib
import contextvars
import functools
import itertools
import json
import random
Expand All @@ -22,7 +25,15 @@
from dataclasses import dataclass, field
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, assert_never, overload
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
ParamSpec,
TypeVar,
assert_never,
overload,
)

import polars as pl

Expand Down Expand Up @@ -55,6 +66,7 @@

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Self

from polars import polars # type: ignore[attr-defined]
Expand All @@ -66,6 +78,9 @@
from cudf_polars.utils.config import ParquetOptions
from cudf_polars.utils.timer import Timer

P = ParamSpec("P")
T = TypeVar("T")

__all__ = [
"IR",
"Cache",
Expand Down Expand Up @@ -105,13 +120,51 @@ class IRExecutionContext:

Parameters
----------
py_executor
Thread pool for thread offload in async execution, only used by
streaming engine.
get_cuda_stream
A zero-argument callable that returns a CUDA stream.
query_id
Identifier for the query being executed.
"""

py_executor: ThreadPoolExecutor | None = field(default=None)
Comment thread
mroeschke marked this conversation as resolved.
get_cuda_stream: Callable[[], Stream] = field(default=get_cuda_stream)
query_id: uuid.UUID = field(default_factory=uuid.uuid4)

async def to_thread(
self, func: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs
) -> T:
"""
Run a function asynchronously in a thread.

Parameters
----------
func
The function to run.
args
Arguments.
kwargs
Keyword arguments.

Returns
-------
Awaitable to obtain the result of calling ``func``.

Notes
-----
This offloads the function to run in the thread pool attached to
this execution context.
"""
assert self.py_executor is not None, (
"Execution context must have a thread pool for offload"
)
Comment thread
wence- marked this conversation as resolved.
loop = asyncio.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
return await loop.run_in_executor(self.py_executor, func_call)
Comment thread
wence- marked this conversation as resolved.

@contextlib.contextmanager
def stream_ordered_after(self, *dfs: DataFrame) -> Generator[Stream, None, None]:
"""
Expand Down Expand Up @@ -2438,9 +2491,11 @@ def do_evaluate(
right.columns,
left=False,
empty=True,
rename=lambda name: name
if name not in left.column_names_set
else f"{name}{suffix}",
rename=lambda name: (
name
if name not in left.column_names_set
else f"{name}{suffix}"
),
stream=stream,
)
result = DataFrame([*left_cols, *right_cols], stream=stream)
Expand All @@ -2454,9 +2509,11 @@ def do_evaluate(
right_cols = Join._build_columns(
columns[left.num_columns :],
right.columns,
rename=lambda name: name
if name not in left.column_names_set
else f"{name}{suffix}",
rename=lambda name: (
name
if name not in left.column_names_set
else f"{name}{suffix}"
),
left=False,
stream=stream,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING

from rapidsmpf.integrations.cudf.partition import unpack_and_concat
Expand All @@ -21,6 +20,8 @@
import pylibcudf as plc
from rmm.pylibrmm.stream import Stream

from cudf_polars.dsl.ir import IRExecutionContext


class AllGatherManager:
"""
Expand Down Expand Up @@ -100,7 +101,7 @@ def inserting(self) -> AllGatherManager.Inserter:
return AllGatherManager.Inserter(self)

async def extract_concatenated(
self, stream: Stream, *, ordered: bool = True
self, stream: Stream, *, ordered: bool = True, ir_context: IRExecutionContext
) -> plc.Table:
"""
Extract the concatenated result.
Expand All @@ -111,12 +112,14 @@ async def extract_concatenated(
The stream to use for chunk extraction.
ordered: bool
Whether to extract the data in ordered or unordered fashion.
ir_context
Execution context to offload concatenation.

Returns
-------
The concatenated AllGather result.
"""
return await asyncio.to_thread(
return await ir_context.to_thread(
unpack_and_concat,
partitions=await self.allgather.extract_all(self.context, ordered=ordered),
stream=stream,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ async def _simple_top_or_bottom_k(
chunk = await evaluate_chunk(
context,
TableChunk.from_pylibcudf_table(
await allgather.extract_concatenated(stream, ordered=True),
await allgather.extract_concatenated(
stream, ordered=True, ir_context=ir_context
),
stream,
exclusive_view=True,
br=context.br(),
Expand Down Expand Up @@ -194,7 +196,9 @@ async def _compute_sort_boundaries(
allgather = AllGatherManager(context, comm, allgather_id)
with allgather.inserting() as inserter:
inserter.insert(comm.rank, chunk)
concat_table = await allgather.extract_concatenated(stream, ordered=True)
concat_table = await allgather.extract_concatenated(
stream, ordered=True, ir_context=ir_context
)
return _get_final_sort_boundaries(
DataFrame.from_table(
concat_table,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def execute_ir_on_rank(
Collected channel metadata.
"""
ir_context = IRExecutionContext(
get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id
py_executor, get_cuda_stream=ctx.get_stream_from_pool, query_id=query_id
)
metadata_collector: list[ChannelMetadata] = []

Expand All @@ -455,7 +455,7 @@ def execute_ir_on_rank(
metadata_collector=metadata_collector,
)

run_actor_network(actors=nodes, py_executor=py_executor)
run_actor_network(ctx, actors=nodes)

messages = output.release()
chunks = [
Expand Down
8 changes: 4 additions & 4 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ async def read_chunk(
context, size=estimated_chunk_bytes, net_memory_delta=estimated_chunk_bytes
)
):
df = await asyncio.to_thread(
df = await ir_context.to_thread(
scan.do_evaluate,
*scan._non_child_args,
context=ir_context,
Expand Down Expand Up @@ -828,7 +828,7 @@ async def sink_node(
).make_available_and_spill(context.br(), allow_overbooking=True)
df = chunk_to_frame(chunk, child_ir)
part_path = f"{path_root}.{str(i).zfill(count_width)}.{suffix}"
await asyncio.to_thread(
await ir_context.to_thread(
Sink.do_evaluate,
ir.sink.schema,
ir.sink.kind,
Expand All @@ -848,7 +848,7 @@ async def sink_node(
).make_available_and_spill(context.br(), allow_overbooking=True)
# Multiple chunks - use chunked writer
df = chunk_to_frame(chunk, child_ir)
writer_state = await asyncio.to_thread(
writer_state = await ir_context.to_thread(
_sink_to_file,
ir.sink.kind,
ir.sink.path,
Expand All @@ -861,7 +861,7 @@ async def sink_node(
if writer_state and ir.sink.kind == "Parquet":
# We know that with ir.sink.kind == "Parquet", writer_state being truthy
# means that it's a ChunkedParquetWriter.
await asyncio.to_thread(writer_state.close, []) # type: ignore[attr-defined]
await ir_context.to_thread(writer_state.close, []) # type: ignore[attr-defined]

# Signal completion on the metadata and data channels with empty results
stream = ir_context.get_cuda_stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal

Expand Down Expand Up @@ -193,7 +192,7 @@ async def _collect_small_side_for_broadcast(
for s_id in range(len(chunks)):
inserter.insert(s_id, chunks.pop(0))
stream = ir_context.get_cuda_stream()
gathered = await allgather.extract_concatenated(stream)
gathered = await allgather.extract_concatenated(stream, ir_context=ir_context)
# When every rank inserted zero chunks, the AllGather has no schema
# to infer and returns a 0 column table. Substitute a properly typed
# empty table for the small side so downstream joins still match the
Expand Down Expand Up @@ -266,7 +265,7 @@ async def _broadcast_join_large_chunk(
await reserve_memory(context, size=input_bytes, net_memory_delta=0)
):
for sdf in dfs_to_join:
result = await asyncio.to_thread(
result = await ir_context.to_thread(
ir.do_evaluate,
*ir._non_child_args,
*([large_df, sdf] if broadcast_side == "right" else [sdf, large_df]),
Expand Down Expand Up @@ -469,7 +468,7 @@ async def _join_chunks(
with opaque_memory_usage(
await reserve_memory(context, size=input_bytes, net_memory_delta=0)
):
df = await asyncio.to_thread(
df = await ir_context.to_thread(
ir.do_evaluate,
*ir._non_child_args,
chunk_to_frame(left_chunk, left),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def default_node_multi(
for chunk, child in zip(ready_chunks, ir.children, strict=True)
]
with opaque_memory_usage(extra):
df = await asyncio.to_thread(
df = await ir_context.to_thread(
ir.do_evaluate,
*ir._non_child_args,
*dfs,
Expand Down
11 changes: 5 additions & 6 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, cast

Expand Down Expand Up @@ -207,7 +206,7 @@ async def _evaluate_broadcast_chunk(
net_memory_delta=0,
)
with opaque_memory_usage(extra):
return await asyncio.to_thread(
return await ir_context.to_thread(
_evaluate_ir_broadcast_sync,
chunk,
ir,
Expand Down Expand Up @@ -366,7 +365,7 @@ async def _allgather_and_broadcast(
net_memory_delta=0,
)
with opaque_memory_usage(extra):
partial = await asyncio.to_thread(
partial = await ir_context.to_thread(
_evaluate_chunk_sync,
chunk,
piecewise_ir,
Expand Down Expand Up @@ -491,7 +490,7 @@ async def _distribute_by_group(
# 1..nranks-1 sit idle on emit. Slice the duplicated input
# across ranks (e.g. stripe by row index) and stamp each
# slice with its target origin rank to distribute emit work.
stamped = await asyncio.to_thread(
stamped = await ir_context.to_thread(
_append_origin_stamps,
chunk,
chunk_index,
Expand Down Expand Up @@ -523,10 +522,10 @@ async def _evaluate_and_route_to_origin(
partition = TableChunk.from_pylibcudf_table(
extracted, stream, exclusive_view=True, br=context.br()
)
evaluated = await asyncio.to_thread(
evaluated = await ir_context.to_thread(
_evaluate_window_with_stamps, partition, ir, ir_context, stamps
)
routed, splits = await asyncio.to_thread(
routed, splits = await ir_context.to_thread(
_partition_by_origin_rank, evaluated, num_ranks, context.br()
)
if routed is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ async def concatenate_node(
del msg

# Extract concatenated result
result_table = await allgather.extract_concatenated(stream)
result_table = await allgather.extract_concatenated(
stream, ir_context=ir_context
)
if tracer is not None:
tracer.add_chunk(table=result_table)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ async def evaluate_chunk(
)
with opaque_memory_usage(extra):
for single_ir in irs:
chunk = await asyncio.to_thread(
chunk = await ir_context.to_thread(
_evaluate_chunk_sync, chunk, single_ir, ir_context, context.br()
)
return chunk
Expand Down Expand Up @@ -544,7 +544,7 @@ async def allgather_and_reduce(
inserter.insert(0, local_chunk)
stream = ir_context.get_cuda_stream()
concat_chunk = TableChunk.from_pylibcudf_table(
await allgather.extract_concatenated(stream),
await allgather.extract_concatenated(stream, ir_context=ir_context),
stream,
exclusive_view=True,
br=context.br(),
Expand Down Expand Up @@ -583,7 +583,7 @@ async def concat_batch(
net_memory_delta=0,
)
with opaque_memory_usage(extra):
df = await asyncio.to_thread(
df = await ir_context.to_thread(
_concat,
*[
DataFrame.from_table(
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ ban-relative-imports = "all"

[tool.ruff.lint.flake8-tidy-imports.banned-api]
"asyncio.gather".msg = "Use gather_with_task_group instead."
"asyncio.to_thread".msg = "Use ir_context.to_thread instead."

[tool.ruff.lint.flake8-type-checking]
strict = true
Expand Down
Loading
Loading