Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ with the [Rust 0.x convention](https://doc.rust-lang.org/cargo/reference/semver.
breaking changes increment the minor version (0.2 → 0.3), additive changes
increment the patch version.

## [Unreleased]

### Changed

- Malformed gzip and zstd compressed payloads now return `invalid_argument`
instead of `internal` ([#139]).

## [0.6.1] - 2026-05-27

A patch release focused on the robustness of the streaming request and
Expand Down Expand Up @@ -46,6 +53,7 @@ now 1.6.
[#131]: https://github.com/anthropics/connect-rust/pull/131
[#132]: https://github.com/anthropics/connect-rust/pull/132
[#133]: https://github.com/anthropics/connect-rust/pull/133
[#139]: https://github.com/anthropics/connect-rust/issues/139

## [0.6.0] - 2026-05-20

Expand Down
95 changes: 79 additions & 16 deletions connectrpc/src/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ use tokio::io::AsyncRead;

use crate::error::ConnectError;

#[cfg(any(feature = "gzip", feature = "zstd"))]
fn malformed_compressed_payload(message: impl Into<String>) -> ConnectError {
ConnectError::invalid_argument(message)
}

// ============================================================================
// Streaming Types
// ============================================================================
Expand Down Expand Up @@ -780,7 +785,9 @@ impl GzipProvider {
&mut output,
flate2::FlushDecompress::None,
)
.map_err(|e| ConnectError::internal(format!("gzip decompression failed: {e}")))?;
.map_err(|e| {
malformed_compressed_payload(format!("gzip decompression failed: {e}"))
})?;
match status {
flate2::Status::StreamEnd => break,
flate2::Status::Ok => {}
Expand All @@ -790,7 +797,7 @@ impl GzipProvider {
// without an end-of-stream marker. Without this check the
// loop would never terminate on such input.
flate2::Status::BufError => {
return Err(ConnectError::internal(
return Err(malformed_compressed_payload(
"gzip decompression stalled: truncated or invalid deflate stream",
));
}
Expand All @@ -809,7 +816,9 @@ impl GzipProvider {
let deflate_consumed = (decompressor.total_in() - start_in) as usize;
let trailer_start = deflate_consumed;
if stream_data.len() < trailer_start + 8 {
return Err(ConnectError::internal("gzip data too short for trailer"));
return Err(malformed_compressed_payload(
"gzip data too short for trailer",
));
}
let trailer = &stream_data[trailer_start..trailer_start + 8];

Expand All @@ -819,10 +828,10 @@ impl GzipProvider {
let mut crc = flate2::Crc::new();
crc.update(&output);
if crc.sum() != expected_crc {
return Err(ConnectError::internal("gzip CRC32 mismatch"));
return Err(malformed_compressed_payload("gzip CRC32 mismatch"));
}
if expected_size != (output.len() as u32) {
return Err(ConnectError::internal("gzip size mismatch"));
return Err(malformed_compressed_payload("gzip size mismatch"));
}

Ok(Bytes::from(output))
Expand All @@ -834,13 +843,15 @@ impl GzipProvider {
#[cfg(feature = "gzip")]
fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
if data.len() < 10 {
return Err(ConnectError::internal("gzip data too short for header"));
return Err(malformed_compressed_payload(
"gzip data too short for header",
));
}
if data[0] != 0x1f || data[1] != 0x8b {
return Err(ConnectError::internal("invalid gzip magic"));
return Err(malformed_compressed_payload("invalid gzip magic"));
}
if data[2] != 0x08 {
return Err(ConnectError::internal(
return Err(malformed_compressed_payload(
"unsupported gzip compression method",
));
}
Expand All @@ -850,7 +861,7 @@ fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
// FEXTRA
if flags & 0x04 != 0 {
if pos + 2 > data.len() {
return Err(ConnectError::internal("truncated gzip header"));
return Err(malformed_compressed_payload("truncated gzip header"));
}
let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
pos += 2 + xlen;
Expand All @@ -862,7 +873,7 @@ fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
pos += 1;
}
if pos >= data.len() {
return Err(ConnectError::internal("truncated gzip header"));
return Err(malformed_compressed_payload("truncated gzip header"));
}
pos += 1; // skip null terminator
}
Expand All @@ -873,7 +884,7 @@ fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
pos += 1;
}
if pos >= data.len() {
return Err(ConnectError::internal("truncated gzip header"));
return Err(malformed_compressed_payload("truncated gzip header"));
}
pos += 1; // skip null terminator
}
Expand All @@ -884,7 +895,7 @@ fn gzip_header_len(data: &[u8]) -> Result<usize, ConnectError> {
}

if pos > data.len() {
return Err(ConnectError::internal("truncated gzip header"));
return Err(malformed_compressed_payload("truncated gzip header"));
}
Ok(pos)
}
Expand Down Expand Up @@ -1034,7 +1045,7 @@ impl ZstdProvider {
use std::io::Read;

let mut decoder = zstd::Decoder::new(data)
.map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?;
.map_err(|e| malformed_compressed_payload(format!("zstd decompression failed: {e}")))?;

let mut decompressed =
Vec::with_capacity(initial_decompress_capacity(data.len(), 4, max_size));
Expand All @@ -1047,7 +1058,7 @@ impl ZstdProvider {
.take((limit as u64).saturating_add(1))
.read_to_end(&mut decompressed)
.map_err(|e| {
ConnectError::internal(format!("zstd decompression failed: {e}"))
malformed_compressed_payload(format!("zstd decompression failed: {e}"))
})?;
if decompressed.len() > limit {
return Err(ConnectError::resource_exhausted(format!(
Expand All @@ -1057,7 +1068,7 @@ impl ZstdProvider {
}
None => {
decoder.read_to_end(&mut decompressed).map_err(|e| {
ConnectError::internal(format!("zstd decompression failed: {e}"))
malformed_compressed_payload(format!("zstd decompression failed: {e}"))
})?;
}
}
Expand Down Expand Up @@ -1087,7 +1098,7 @@ impl CompressionProvider for ZstdProvider {
data: &'a [u8],
) -> Result<Box<dyn std::io::Read + 'a>, ConnectError> {
let decoder = zstd::Decoder::new(data)
.map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?;
.map_err(|e| malformed_compressed_payload(format!("zstd decompression failed: {e}")))?;
Ok(Box::new(decoder))
}

Expand Down Expand Up @@ -1147,6 +1158,15 @@ fn initial_decompress_capacity(
mod tests {
use super::*;

#[cfg(any(feature = "gzip", feature = "zstd"))]
fn assert_invalid_argument(err: &ConnectError) {
assert_eq!(
err.code,
crate::error::ErrorCode::InvalidArgument,
"{err:?}"
);
}

#[test]
fn test_empty_registry() {
let registry = CompressionRegistry::new();
Expand Down Expand Up @@ -1549,6 +1569,7 @@ mod tests {
GzipProvider::default().decompress_with_limit(&MINIMAL_GZIP_HEADER, 1024)
})
.expect_err("header-only gzip member must be rejected");
assert_invalid_argument(&err);
assert!(
err.to_string()
.contains("truncated or invalid deflate stream"),
Expand All @@ -1571,6 +1592,7 @@ mod tests {
provider.decompress_with_limit(&compressed[..14], 1024)
})
.expect_err("truncated deflate stream must be rejected");
assert_invalid_argument(&err);
// Which check rejects the prefix depends on where the deflate encoder
// happened to place block boundaries: an incomplete block is caught by
// the stalled-stream handling, while a prefix that ends on a complete
Expand Down Expand Up @@ -1598,6 +1620,7 @@ mod tests {
)
})
.expect_err("truncated gzip payload must be rejected via the registry");
assert_invalid_argument(&err);
assert!(
err.to_string()
.contains("truncated or invalid deflate stream"),
Expand All @@ -1620,12 +1643,43 @@ mod tests {
let err = provider
.decompress_with_limit(missing_trailer, 1024)
.expect_err("gzip member without its trailer must be rejected");
assert_invalid_argument(&err);
assert!(
err.to_string().contains("too short for trailer"),
"unexpected error message: {err}"
);
}

#[cfg(feature = "gzip")]
#[test]
fn test_gzip_malformed_payloads_are_invalid_argument() {
let provider = GzipProvider::default();

let err = provider
.decompress_with_limit(b"not gzip", 1024)
.expect_err("bad gzip header must be rejected");
assert_invalid_argument(&err);

let data = b"hello world, this is a test of gzip compression";
let compressed = provider.compress(data).unwrap();
let trailer_start = compressed.len() - 8;

let mut bad_crc = compressed.to_vec();
bad_crc[trailer_start] ^= 0xff;
let err = provider
.decompress_with_limit(&bad_crc, 1024)
.expect_err("gzip CRC mismatch must be rejected");
assert_invalid_argument(&err);

let mut bad_size = compressed.to_vec();
let last = bad_size.len() - 1;
bad_size[last] ^= 0xff;
let err = provider
.decompress_with_limit(&bad_size, 1024)
.expect_err("gzip size mismatch must be rejected");
assert_invalid_argument(&err);
}

#[cfg(feature = "gzip")]
#[test]
fn test_gzip_registry() {
Expand Down Expand Up @@ -1657,6 +1711,15 @@ mod tests {
assert_eq!(&decompressed[..], data);
}

#[cfg(feature = "zstd")]
#[test]
fn test_zstd_malformed_payload_is_invalid_argument() {
let err = ZstdProvider::default()
.decompress_with_limit(b"not zstd", 1024)
.expect_err("malformed zstd payload must be rejected");
assert_invalid_argument(&err);
}

#[cfg(feature = "zstd")]
#[test]
fn test_zstd_high_compression_ratio() {
Expand Down