From 5e935f8af1a895c8ae73e5bb226cd14b1d2f79bf Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 6 Feb 2020 11:31:23 -0800 Subject: [PATCH 1/6] Add `as_cuda_array` --- distributed/comm/ucx.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 175d628a0f6..a6f8e5e29ee 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -42,6 +42,16 @@ def init_once(): ucp = _ucp ucp.init(options=dask.config.get("ucx"), env_takes_precedence=True) + # Find the function, `as_cuda_array()`, to get array-likes from CUDA + try: + import numba.cuda + + as_cuda_array = lambda a: numba.cuda.as_cuda_array(a) + except ImportError: + + def as_cuda_array(a): + raise RuntimeError("In order to send/recv CUDA arrays, Numba is required") + # Find the function, `cuda_array()`, to use when allocating new CUDA arrays try: import rmm From bd3b0c9d18ec14840196ed752cc9df9446059015 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 6 Feb 2020 11:31:29 -0800 Subject: [PATCH 2/6] Coerce `DeviceBuffer` to Numba `DeviceNDArray` As we need to index into the object and `DeviceBuffer` currently lacks a way to index it, go ahead and coerce `DeviceBuffer`s to Numba `DeviceNDArray`s. This way we can be sure we will be able to index it when needed. --- distributed/comm/ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index a6f8e5e29ee..0f62105ab7f 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -57,7 +57,7 @@ def as_cuda_array(a): import rmm if hasattr(rmm, "DeviceBuffer"): - cuda_array = lambda n: rmm.DeviceBuffer(size=n) + cuda_array = lambda n: as_cuda_array(rmm.DeviceBuffer(size=n)) else: # pre-0.11.0 cuda_array = lambda n: rmm.device_array(n, dtype=np.uint8) except ImportError: From e3cbf12a74972d9d69e2dfc3daa7b94c30c4f7db Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 6 Feb 2020 11:31:32 -0800 Subject: [PATCH 3/6] Note that we always require Numba --- distributed/comm/ucx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 0f62105ab7f..e1a51715879 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -69,7 +69,7 @@ def as_cuda_array(a): def cuda_array(n): raise RuntimeError( - "In order to send/recv CUDA arrays, Numba or RMM is required" + "In order to send/recv CUDA arrays, Numba and RMM are required" ) From 9040498794bc603ada3d75fd4fa58cc9b86842da Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 6 Feb 2020 12:21:53 -0800 Subject: [PATCH 4/6] Aggregate allocations --- distributed/comm/tcp.py | 36 +++++++++++++++++++++++++----------- distributed/comm/ucx.py | 33 +++++++++++++++++++-------------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 7003053ce06..09f33e19632 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -26,6 +26,14 @@ from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host +import asyncio +from itertools import starmap +from operator import add + +try: + from cytoolz import accumulate, cons, sliding_window +except ImportError: + from toolz import accumulate, cons, sliding_window logger = logging.getLogger(__name__) @@ -190,18 +198,24 @@ async def read(self, deserializers=None): lengths = await stream.read_bytes(8 * n_frames) lengths = struct.unpack("Q" * n_frames, lengths) - frames = [] - for length in lengths: - if length: - if self._iostream_has_read_into: - frame = bytearray(length) - n = await stream.read_into(frame) - assert n == length, (n, length) - else: + if self._iostream_has_read_into: + frame_arr = bytearray(sum(lengths)) + slices = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes))) + ) + frames = [frames_arr[sl] for sl in slices] + recvd_lengths = await asyncio.gather([ + stream.read_into(f) for f in frames if len(f) + ]) + assert all(recvd_lengths == lengths), (recvd_lengths, lengths) + else: + frames = [] + for length in lengths: + if length: frame = await stream.read_bytes(length) - else: - frame = b"" - frames.append(frame) + else: + frame = b"" + frames.append(frame) except StreamClosedError as e: self.stream = None if not shutting_down(): diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index e1a51715879..aefbf3bd452 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -22,6 +22,15 @@ import dask import numpy as np +import asyncio +from itertools import starmap +from operator import add + +try: + from cytoolz import accumulate, cons, sliding_window +except ImportError: + from toolz import accumulate, cons, sliding_window + logger = logging.getLogger(__name__) @@ -188,20 +197,16 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames = [] - for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): - if size > 0: - if is_cuda: - frame = cuda_array(size) - else: - frame = np.empty(size, dtype=np.uint8) - await self.ep.recv(frame) - frames.append(frame) - else: - if is_cuda: - frames.append(cuda_array(size)) - else: - frames.append(b"") + frames_dev_arr = cuda_array(sum(sizes[is_cudas])) + frames_host_arr = np.empty(sum(sizes[~is_cudas]), dtype=np.uint8) + slices = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes))) + ) + frames = [ + frames_dev_arr[sl] if is_cuda else frames_host_arr[sl] + for is_cuda, size in zip(is_cudas.tolist(), slices) + ] + await asyncio.gather([self.ep.recv(f) for f in frames if len(f)]) msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers ) From ce5d4685e3dddf1ebe08e1abdca4b97213272dd7 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 6 Feb 2020 22:02:29 -0800 Subject: [PATCH 5/6] Handle host + dev frames separately then merge --- distributed/comm/ucx.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index aefbf3bd452..bc32ade3415 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -197,15 +197,23 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames_dev_arr = cuda_array(sum(sizes[is_cudas])) - frames_host_arr = np.empty(sum(sizes[~is_cudas]), dtype=np.uint8) - slices = starmap( - slice, sliding_window(2, accumulate(add, cons(0, sizes))) + sizes_dev = sizes[is_cudas] + sizes_host = sizes[~is_cudas] + frames_dev_arr = cuda_array(sum(sizes_dev)) + frames_host_arr = np.empty(sum(sizes_host), dtype=np.uint8) + slices_dev = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes_dev))) ) - frames = [ - frames_dev_arr[sl] if is_cuda else frames_host_arr[sl] - for is_cuda, size in zip(is_cudas.tolist(), slices) - ] + slices_host = starmap( + slice, sliding_window(2, accumulate(add, cons(0, sizes_host))) + ) + frames_dev = [frames_dev_arr[sl] for sl in slices_dev] + frames_host = [frames_host_arr[sl] for sl in slices_host] + frames = len(sizes) * [None] + for i, f in zip(is_cudas.nonzero()[0], frames_dev): + frames[i] = f + for i, f in zip((~is_cudas).nonzero()[0], frames_host): + frames[i] = f await asyncio.gather([self.ep.recv(f) for f in frames if len(f)]) msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers From 59ca551b9063c224c17e33a769f4b6600d8d8a1f Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Mon, 16 Mar 2020 16:14:00 -0700 Subject: [PATCH 6/6] Import from `tlz` for optional `cytoolz` support --- distributed/comm/tcp.py | 5 +---- distributed/comm/ucx.py | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py index 09f33e19632..67bf6175872 100644 --- a/distributed/comm/tcp.py +++ b/distributed/comm/tcp.py @@ -30,10 +30,7 @@ from itertools import starmap from operator import add -try: - from cytoolz import accumulate, cons, sliding_window -except ImportError: - from toolz import accumulate, cons, sliding_window +from tlz import accumulate, cons, sliding_window logger = logging.getLogger(__name__) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index bc32ade3415..8586122888b 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -26,10 +26,7 @@ from itertools import starmap from operator import add -try: - from cytoolz import accumulate, cons, sliding_window -except ImportError: - from toolz import accumulate, cons, sliding_window +from tlz import accumulate, cons, sliding_window logger = logging.getLogger(__name__)