Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ async def write(
frames = await to_frames(
msg, serializers=serializers, on_error=on_error
)
send_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# Send meta data
cuda_frames = np.array(
Expand All @@ -167,6 +170,7 @@ async def write(
await self.ep.send(
np.array([nbytes(f) for f in frames], dtype=np.uint64)
)

# Send frames

# It is necessary to first synchronize the default stream before start sending
Expand All @@ -177,10 +181,9 @@ async def write(
if cuda_frames.any():
synchronize_stream(0)

for frame in frames:
if nbytes(frame) > 0:
await self.ep.send(frame)
return sum(map(nbytes, frames))
for each_frame in send_frames:
await self.ep.send(each_frame)
return sum(map(nbytes, send_frames))
except (ucp.exceptions.UCXBaseException):
self.abort()
raise CommClosedError("While writing, the connection was closed")
Expand All @@ -206,30 +209,23 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
raise CommClosedError("While reading, the connection was closed")
else:
# Recv frames
frames = []
for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()):
if size > 0:
if is_cuda:
frame = cuda_array(size)
else:
frame = np.empty(size, dtype=np.uint8)
frames.append(frame)
else:
if is_cuda:
frames.append(cuda_array(size))
else:
frames.append(b"")
frames = [
cuda_array(each_size)
if is_cuda
else np.empty(each_size, dtype=np.uint8)
for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist())
]
recv_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# It is necessary to first populate `frames` with CUDA arrays and synchronize
# the default stream before starting receiving to ensure buffers have been allocated
if is_cudas.any():
synchronize_stream(0)
for i, (is_cuda, size) in enumerate(
zip(is_cudas.tolist(), sizes.tolist())
):
if size > 0:
await self.ep.recv(frames[i])

for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down