diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 7003053ce06..67bf6175872 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -26,6 +26,11 @@ from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host +import asyncio +from itertools import starmap +from operator import add + +from tlz import accumulate, cons, sliding_window logger = logging.getLogger(__name__) @@ -190,18 +195,24 @@ async def read(self, deserializers=None): lengths = await stream.read_bytes(8 * n_frames) lengths = struct.unpack("Q" * n_frames, lengths) - frames = [] - for length in lengths: - if length: - if self._iostream_has_read_into: - frame = bytearray(length) - n = await stream.read_into(frame) - assert n == length, (n, length) - else: + if self._iostream_has_read_into: + frame_arr = bytearray(sum(lengths)) + slices = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes))) + ) + frames = [frames_arr[sl] for sl in slices] + recvd_lengths = await asyncio.gather([ + stream.read_into(f) for f in frames if len(f) + ]) + assert all(recvd_lengths == lengths), (recvd_lengths, lengths) + else: + frames = [] + for length in lengths: + if length: frame = await stream.read_bytes(length) - else: - frame = b"" - frames.append(frame) + else: + frame = b"" + frames.append(frame) except StreamClosedError as e: self.stream = None if not shutting_down(): diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 175d628a0f6..8586122888b 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -22,6 +22,12 @@ import dask import numpy as np +import asyncio +from itertools import starmap +from operator import add + +from tlz import accumulate, cons, sliding_window + logger = logging.getLogger(__name__) @@ -42,12 +48,22 @@ def init_once(): ucp = _ucp ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) + # Find the function, `as_cuda_array()`, to get array-likes from CUDA + try: + import numba.cuda + + as_cuda_array = lambda a: numba.cuda.as_cuda_array(a) + except ImportError: + + def as_cuda_array(a): + raise RuntimeError("In order to send/recv CUDA arrays, Numba is required") + # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: import rmm if hasattr(rmm, "DeviceBuffer"): - cuda_array = lambda n: rmm.DeviceBuffer(size=n) + cuda_array = lambda n: as_cuda_array(rmm.DeviceBuffer(size=n)) else: # pre-0.11.0 cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) except ImportError: @@ -59,7 +75,7 @@ def init_once(): def cuda_array(n): raise RuntimeError( - "In order to send/recv CUDA arrays, Numba or RMM is required" + "In order to send/recv CUDA arrays, Numba and RMM are required" ) @@ -178,20 +194,24 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames = [] - for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): - if size > 0: - if is_cuda: - frame = cuda_array(size) - else: - frame = np.empty(size, dtype=np.uint8) - await self.ep.recv(frame) - frames.append(frame) - else: - if is_cuda: - frames.append(cuda_array(size)) - else: - frames.append(b"") + sizes_dev = sizes[is_cudas] + sizes_host = sizes[~is_cudas] + frames_dev_arr = cuda_array(sum(sizes_dev)) + frames_host_arr = np.empty(sum(sizes_host), dtype=np.uint8) + slices_dev = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes_dev))) + ) + slices_host = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes_host))) + ) + frames_dev = [frames_dev_arr[sl] for sl in slices_dev] + frames_host = [frames_host_arr[sl] for sl in slices_host] + frames = len(sizes) * [None] + for i, f in zip(is_cudas.nonzero()[0], frames_dev): + frames[i] = f + for i, f in zip((~is_cudas).nonzero()[0], frames_host): + frames[i] = f + await asyncio.gather([self.ep.recv(f) for f in frames if len(f)]) msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers )