diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index fc187dcc614..a29441ec4d5 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -156,6 +156,9 @@ async def write( frames = await to_frames( msg, serializers=serializers, on_error=on_error ) + send_frames = [ + each_frame for each_frame in frames if len(each_frame) > 0 + ] # Send meta data cuda_frames = np.array( @@ -167,6 +170,7 @@ async def write( await self.ep.send( np.array([nbytes(f) for f in frames], dtype=np.uint64) ) + # Send frames # It is necessary to first synchronize the default stream before start sending @@ -177,10 +181,9 @@ async def write( if cuda_frames.any(): synchronize_stream(0) - for frame in frames: - if nbytes(frame) > 0: - await self.ep.send(frame) - return sum(map(nbytes, frames)) + for each_frame in send_frames: + await self.ep.send(each_frame) + return sum(map(nbytes, send_frames)) except (ucp.exceptions.UCXBaseException): self.abort() raise CommClosedError("While writing, the connection was closed") @@ -206,30 +209,23 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")): raise CommClosedError("While reading, the connection was closed") else: # Recv frames - frames = [] - for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()): - if size > 0: - if is_cuda: - frame = cuda_array(size) - else: - frame = np.empty(size, dtype=np.uint8) - frames.append(frame) - else: - if is_cuda: - frames.append(cuda_array(size)) - else: - frames.append(b"") + frames = [ + cuda_array(each_size) + if is_cuda + else np.empty(each_size, dtype=np.uint8) + for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist()) + ] + recv_frames = [ + each_frame for each_frame in frames if len(each_frame) > 0 + ] # It is necessary to first populate `frames` with CUDA arrays and synchronize # the default stream before starting receiving to ensure buffers have been allocated if is_cudas.any(): synchronize_stream(0) - for i, (is_cuda, size) in enumerate( - zip(is_cudas.tolist(), sizes.tolist()) - ): - if size > 0: - await self.ep.recv(frames[i]) + for each_frame in recv_frames: + await self.ep.recv(each_frame) msg = await from_frames( frames, deserialize=self.deserialize, deserializers=deserializers )