Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 117 additions & 9 deletions connectrpc/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,32 @@ pub fn full_body(b: Bytes) -> ClientBody {
/// under 8 KiB per header set.
const RESPONSE_BUFFER_TRAILER_SLACK: usize = 64 * 1024;

/// Return the end offset of a complete gRPC-Web trailer frame, if present.
fn grpc_web_trailer_frame_end(data: &[u8]) -> Option<usize> {
let mut offset = 0;

while data.len().saturating_sub(offset) >= crate::envelope::HEADER_SIZE {
let length = u32::from_be_bytes([
data[offset + 1],
data[offset + 2],
data[offset + 3],
data[offset + 4],
]) as usize;
let frame_end = offset
.checked_add(crate::envelope::HEADER_SIZE)?
.checked_add(length)?;
if frame_end > data.len() {
return None;
}
if data[offset] & 0x80 != 0 {
return Some(frame_end);
}
offset = frame_end;
}

None
}

/// Trait for types that can be used as ConnectRPC client transports.
///
/// This is automatically implemented for any `tower::Service` that handles
Expand Down Expand Up @@ -1660,12 +1686,20 @@ where
if !data.is_empty() {
has_body_data = true;
}
if buf.len().saturating_add(data.len()) > max_buf_size {
let remaining = max_buf_size.saturating_sub(buf.len());
let append_len = data.len().min(remaining);
buf.extend_from_slice(&data[..append_len]);
if matches!(config.protocol, Protocol::GrpcWeb)
&& let Some(trailer_end) = grpc_web_trailer_frame_end(&buf)
{
buf.truncate(trailer_end);
break;
}
if append_len < data.len() {
return Err(ConnectError::resource_exhausted(format!(
"response body size exceeds limit {max_buf_size}"
)));
}
buf.extend_from_slice(&data);
}
} else if frame.is_trailers()
&& let Ok(trailers) = frame.into_trailers()
Expand Down Expand Up @@ -1718,6 +1752,13 @@ where
}
};

if message_count > 0 {
let mut err =
ConnectError::unimplemented("received multiple messages for unary response");
err.set_response_headers(resp_headers);
return Err(err);
}

let data = if envelope.is_compressed() {
let enc = response_encoding.as_deref().ok_or_else(|| {
ConnectError::internal("received compressed message without grpc-encoding header")
Expand Down Expand Up @@ -1755,13 +1796,6 @@ where
return Err(err);
}

// Validate message count for unary/client-stream (expect exactly 1)
if message_count > 1 {
let mut err = ConnectError::unimplemented("received multiple messages for unary response");
err.set_response_headers(resp_headers);
return Err(err);
}

// For missing grpc-status, synthesize an error.
// If a deadline was set and has passed, map to DEADLINE_EXCEEDED per the gRPC
// spec: RST_STREAM CANCEL is upgraded to DeadlineExceeded when the deadline
Expand Down Expand Up @@ -3977,6 +4011,80 @@ mod tests {
assert_eq!(headers.get("grpc-status").unwrap().to_str().unwrap(), "0");
}

#[tokio::test]
async fn grpc_unary_rejects_second_message_before_decompression() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;

let mut body = BytesMut::new();
body.extend_from_slice(&Envelope::data(Bytes::new()).encode());
body.extend_from_slice(&Envelope::compressed(Bytes::from_static(b"not-gzip")).encode());

let response = Response::builder()
.header(http::header::CONTENT_TYPE, "application/grpc+proto")
.body(Full::new(body.freeze()))
.unwrap();
let config =
ClientConfig::new("http://localhost".parse().unwrap()).with_protocol(Protocol::Grpc);

let err = parse_grpc_unary_response::<_, StringValueView<'static>>(
response,
&config,
&CallOptions::default(),
None,
)
.await
.unwrap_err();
assert_eq!(err.code, ErrorCode::Unimplemented);
assert_eq!(
err.message.as_deref(),
Some("received multiple messages for unary response")
);
}

#[tokio::test]
async fn grpc_web_unary_stops_reading_after_trailer_frame() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;

let mut body = BytesMut::new();
body.extend_from_slice(&Envelope::data(Bytes::from_static(b"\x0a\x02hi")).encode());
let trailer_payload = b"grpc-status: 0\r\n";
body.extend_from_slice(&[0x80]);
body.extend_from_slice(&(trailer_payload.len() as u32).to_be_bytes());
body.extend_from_slice(trailer_payload);

let (tx, rx) = tokio::sync::mpsc::channel(2);
tx.send(Ok(body.freeze())).await.unwrap();
tx.send(Ok(Bytes::from_static(b"server is still writing")))
.await
.unwrap();

// Keep the sender alive: a complete trailers frame must finish the
// response without waiting for EOF or consuming the queued bytes.
let response = Response::builder()
.header(http::header::CONTENT_TYPE, "application/grpc-web+proto")
.body(ChannelBody { rx })
.unwrap();
let config =
ClientConfig::new("http://localhost".parse().unwrap()).with_protocol(Protocol::GrpcWeb);

let response = tokio::time::timeout(
Duration::from_secs(1),
parse_grpc_unary_response::<_, StringValueView<'static>>(
response,
&config,
&CallOptions::default(),
None,
),
)
.await
.expect("parser should stop after the gRPC-Web trailer frame")
.unwrap();
assert_eq!(
response.trailers().get("grpc-status").unwrap(),
http::HeaderValue::from_static("0")
);
}

// ========================================================================
// Content type helper tests
// ========================================================================
Expand Down