diff --git a/CHANGELOG.md b/CHANGELOG.md index 003d84d..67ac1c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ with the [Rust 0.x convention](https://doc.rust-lang.org/cargo/reference/semver. breaking changes increment the minor version (0.2 → 0.3), additive changes increment the patch version. +## [Unreleased] + +### Changed + +- Malformed gzip and zstd compressed payloads now return `invalid_argument` + instead of `internal` ([#139]). + ## [0.6.1] - 2026-05-27 A patch release focused on the robustness of the streaming request and @@ -46,6 +53,7 @@ now 1.6. [#131]: https://github.com/anthropics/connect-rust/pull/131 [#132]: https://github.com/anthropics/connect-rust/pull/132 [#133]: https://github.com/anthropics/connect-rust/pull/133 +[#139]: https://github.com/anthropics/connect-rust/issues/139 ## [0.6.0] - 2026-05-20 diff --git a/connectrpc/src/compression.rs b/connectrpc/src/compression.rs index e9c9dd6..729aa42 100644 --- a/connectrpc/src/compression.rs +++ b/connectrpc/src/compression.rs @@ -65,6 +65,11 @@ use tokio::io::AsyncRead; use crate::error::ConnectError; +#[cfg(any(feature = "gzip", feature = "zstd"))] +fn malformed_compressed_payload(message: impl Into) -> ConnectError { + ConnectError::invalid_argument(message) +} + // ============================================================================ // Streaming Types // ============================================================================ @@ -780,7 +785,9 @@ impl GzipProvider { &mut output, flate2::FlushDecompress::None, ) - .map_err(|e| ConnectError::internal(format!("gzip decompression failed: {e}")))?; + .map_err(|e| { + malformed_compressed_payload(format!("gzip decompression failed: {e}")) + })?; match status { flate2::Status::StreamEnd => break, flate2::Status::Ok => {} @@ -790,7 +797,7 @@ impl GzipProvider { // without an end-of-stream marker. Without this check the // loop would never terminate on such input. flate2::Status::BufError => { - return Err(ConnectError::internal( + return Err(malformed_compressed_payload( "gzip decompression stalled: truncated or invalid deflate stream", )); } @@ -809,7 +816,9 @@ impl GzipProvider { let deflate_consumed = (decompressor.total_in() - start_in) as usize; let trailer_start = deflate_consumed; if stream_data.len() < trailer_start + 8 { - return Err(ConnectError::internal("gzip data too short for trailer")); + return Err(malformed_compressed_payload( + "gzip data too short for trailer", + )); } let trailer = &stream_data[trailer_start..trailer_start + 8]; @@ -819,10 +828,10 @@ impl GzipProvider { let mut crc = flate2::Crc::new(); crc.update(&output); if crc.sum() != expected_crc { - return Err(ConnectError::internal("gzip CRC32 mismatch")); + return Err(malformed_compressed_payload("gzip CRC32 mismatch")); } if expected_size != (output.len() as u32) { - return Err(ConnectError::internal("gzip size mismatch")); + return Err(malformed_compressed_payload("gzip size mismatch")); } Ok(Bytes::from(output)) @@ -834,13 +843,15 @@ impl GzipProvider { #[cfg(feature = "gzip")] fn gzip_header_len(data: &[u8]) -> Result { if data.len() < 10 { - return Err(ConnectError::internal("gzip data too short for header")); + return Err(malformed_compressed_payload( + "gzip data too short for header", + )); } if data[0] != 0x1f || data[1] != 0x8b { - return Err(ConnectError::internal("invalid gzip magic")); + return Err(malformed_compressed_payload("invalid gzip magic")); } if data[2] != 0x08 { - return Err(ConnectError::internal( + return Err(malformed_compressed_payload( "unsupported gzip compression method", )); } @@ -850,7 +861,7 @@ fn gzip_header_len(data: &[u8]) -> Result { // FEXTRA if flags & 0x04 != 0 { if pos + 2 > data.len() { - return Err(ConnectError::internal("truncated gzip header")); + return Err(malformed_compressed_payload("truncated gzip header")); } let xlen = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize; pos += 2 + xlen; @@ -862,7 +873,7 @@ fn gzip_header_len(data: &[u8]) -> Result { pos += 1; } if pos >= data.len() { - return Err(ConnectError::internal("truncated gzip header")); + return Err(malformed_compressed_payload("truncated gzip header")); } pos += 1; // skip null terminator } @@ -873,7 +884,7 @@ fn gzip_header_len(data: &[u8]) -> Result { pos += 1; } if pos >= data.len() { - return Err(ConnectError::internal("truncated gzip header")); + return Err(malformed_compressed_payload("truncated gzip header")); } pos += 1; // skip null terminator } @@ -884,7 +895,7 @@ fn gzip_header_len(data: &[u8]) -> Result { } if pos > data.len() { - return Err(ConnectError::internal("truncated gzip header")); + return Err(malformed_compressed_payload("truncated gzip header")); } Ok(pos) } @@ -1034,7 +1045,7 @@ impl ZstdProvider { use std::io::Read; let mut decoder = zstd::Decoder::new(data) - .map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?; + .map_err(|e| malformed_compressed_payload(format!("zstd decompression failed: {e}")))?; let mut decompressed = Vec::with_capacity(initial_decompress_capacity(data.len(), 4, max_size)); @@ -1047,7 +1058,7 @@ impl ZstdProvider { .take((limit as u64).saturating_add(1)) .read_to_end(&mut decompressed) .map_err(|e| { - ConnectError::internal(format!("zstd decompression failed: {e}")) + malformed_compressed_payload(format!("zstd decompression failed: {e}")) })?; if decompressed.len() > limit { return Err(ConnectError::resource_exhausted(format!( @@ -1057,7 +1068,7 @@ impl ZstdProvider { } None => { decoder.read_to_end(&mut decompressed).map_err(|e| { - ConnectError::internal(format!("zstd decompression failed: {e}")) + malformed_compressed_payload(format!("zstd decompression failed: {e}")) })?; } } @@ -1087,7 +1098,7 @@ impl CompressionProvider for ZstdProvider { data: &'a [u8], ) -> Result, ConnectError> { let decoder = zstd::Decoder::new(data) - .map_err(|e| ConnectError::internal(format!("zstd decompression failed: {e}")))?; + .map_err(|e| malformed_compressed_payload(format!("zstd decompression failed: {e}")))?; Ok(Box::new(decoder)) } @@ -1147,6 +1158,15 @@ fn initial_decompress_capacity( mod tests { use super::*; + #[cfg(any(feature = "gzip", feature = "zstd"))] + fn assert_invalid_argument(err: &ConnectError) { + assert_eq!( + err.code, + crate::error::ErrorCode::InvalidArgument, + "{err:?}" + ); + } + #[test] fn test_empty_registry() { let registry = CompressionRegistry::new(); @@ -1549,6 +1569,7 @@ mod tests { GzipProvider::default().decompress_with_limit(&MINIMAL_GZIP_HEADER, 1024) }) .expect_err("header-only gzip member must be rejected"); + assert_invalid_argument(&err); assert!( err.to_string() .contains("truncated or invalid deflate stream"), @@ -1571,6 +1592,7 @@ mod tests { provider.decompress_with_limit(&compressed[..14], 1024) }) .expect_err("truncated deflate stream must be rejected"); + assert_invalid_argument(&err); // Which check rejects the prefix depends on where the deflate encoder // happened to place block boundaries: an incomplete block is caught by // the stalled-stream handling, while a prefix that ends on a complete @@ -1598,6 +1620,7 @@ mod tests { ) }) .expect_err("truncated gzip payload must be rejected via the registry"); + assert_invalid_argument(&err); assert!( err.to_string() .contains("truncated or invalid deflate stream"), @@ -1620,12 +1643,43 @@ mod tests { let err = provider .decompress_with_limit(missing_trailer, 1024) .expect_err("gzip member without its trailer must be rejected"); + assert_invalid_argument(&err); assert!( err.to_string().contains("too short for trailer"), "unexpected error message: {err}" ); } + #[cfg(feature = "gzip")] + #[test] + fn test_gzip_malformed_payloads_are_invalid_argument() { + let provider = GzipProvider::default(); + + let err = provider + .decompress_with_limit(b"not gzip", 1024) + .expect_err("bad gzip header must be rejected"); + assert_invalid_argument(&err); + + let data = b"hello world, this is a test of gzip compression"; + let compressed = provider.compress(data).unwrap(); + let trailer_start = compressed.len() - 8; + + let mut bad_crc = compressed.to_vec(); + bad_crc[trailer_start] ^= 0xff; + let err = provider + .decompress_with_limit(&bad_crc, 1024) + .expect_err("gzip CRC mismatch must be rejected"); + assert_invalid_argument(&err); + + let mut bad_size = compressed.to_vec(); + let last = bad_size.len() - 1; + bad_size[last] ^= 0xff; + let err = provider + .decompress_with_limit(&bad_size, 1024) + .expect_err("gzip size mismatch must be rejected"); + assert_invalid_argument(&err); + } + #[cfg(feature = "gzip")] #[test] fn test_gzip_registry() { @@ -1657,6 +1711,15 @@ mod tests { assert_eq!(&decompressed[..], data); } + #[cfg(feature = "zstd")] + #[test] + fn test_zstd_malformed_payload_is_invalid_argument() { + let err = ZstdProvider::default() + .decompress_with_limit(b"not zstd", 1024) + .expect_err("malformed zstd payload must be rejected"); + assert_invalid_argument(&err); + } + #[cfg(feature = "zstd")] #[test] fn test_zstd_high_compression_ratio() {