diff --git a/Cargo.lock b/Cargo.lock index 69840ec4..775b1452 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2405,7 +2405,9 @@ dependencies = [ "praxis-proxy-filter", "praxis-proxy-proto", "prost-types", + "prost-wkt-types", "serde", + "sync_wrapper", "thiserror 2.0.18", "tokio", "tokio-stream", diff --git a/Cargo.toml b/Cargo.toml index 73356875..ca1c1449 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,6 +81,7 @@ rustls-pemfile = "2.2.0" serde_yaml = { package = "yaml_serde", version = "0.10.4" } sqlx = { version = "0.8.6", default-features = false, features = ["runtime-tokio-rustls", "sqlite", "postgres"] } syn = { version = "2.0.118", features = ["full", "extra-traits", "visit"] } +sync_wrapper = "1.0.2" tempfile = "3.27.0" tokio-rustls = "0.26.4" tokio-tungstenite = "0.29.0" diff --git a/filter/ext-proc/Cargo.toml b/filter/ext-proc/Cargo.toml index e3ec6438..51eba126 100644 --- a/filter/ext-proc/Cargo.toml +++ b/filter/ext-proc/Cargo.toml @@ -22,11 +22,15 @@ bytes = { workspace = true } futures = { workspace = true } http = { workspace = true } praxis-filter = { workspace = true } +prost-types = { workspace = true } praxis-proto = { workspace = true } +prost-wkt-types = { workspace = true } serde = { workspace = true } serde_yaml = { workspace = true } +sync_wrapper = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } +tokio-stream = { workspace = true } tonic = { workspace = true } tracing = { workspace = true } diff --git a/filter/ext-proc/src/duplex.rs b/filter/ext-proc/src/duplex.rs new file mode 100644 index 00000000..1b99b79c --- /dev/null +++ b/filter/ext-proc/src/duplex.rs @@ -0,0 +1,1349 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 Praxis Contributors + +//! Persistent bidirectional `ext_proc` exchange state machine. +//! +//! Opens one [`ExternalProcessor.Process`] gRPC stream per HTTP +//! request and sends/receives multiple messages across request +//! and response phases. +//! +//! Sending and receiving are independent. Response envelopes are +//! classified into typed [`ExchangeEvent`] variants with +//! processor-output validation. Request and response directions +//! are tracked independently for both outbound and +//! processor-output phases. +//! +//! # State Domains +//! +//! The exchange tracks six orthogonal state domains: +//! +//! 1. **`terminal`** — shared terminal flag. +//! 2. **`request_send`** — outbound send phase for the request direction. +//! 3. **`response_send`** — outbound send phase for the response direction. +//! 4. **`request_output`** — processor-output phase for the request direction. +//! 5. **`response_output`** — processor-output phase for the response direction. +//! 6. **`active_processing`** — optional per-message processing state with deadline and override tracking. +//! +//! # Non-Full-Duplex vs Full-Duplex +//! +//! In non-full-duplex modes, every sent message (including every +//! body chunk) creates an [`ActiveProcessingState`] with a +//! deadline. At most one may be outstanding — a second send while +//! one is active fails with [`ExchangeError::OrderingViolation`]. +//! +//! In full-duplex mode (`FULL_DUPLEX_STREAMED`), no messages in +//! the direction — headers, body chunks, or trailers — create +//! active processing state. The entire direction operates without +//! per-message timeouts. +//! +//! No background tasks are spawned. The bounded request channel +//! feeds tonic's h2 connection driver, which polls it lazily. +//! +//! [`ExternalProcessor.Process`]: praxis_proto::envoy::service::ext_proc::v3::external_processor_client::ExternalProcessorClient::process + +use std::time::Duration; + +use praxis_proto::envoy::service::ext_proc::v3::{ + BodyResponse, HeadersResponse, ImmediateResponse, ProcessingRequest, ProcessingResponse, ProtocolConfiguration, + TrailersResponse, external_processor_client::ExternalProcessorClient, processing_request, processing_response, +}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::transport::Channel; + +use crate::BodySendMode; + +// ----------------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------------- + +/// Bounded channel capacity for the request stream. +/// +/// Capacity 1 provides tighter backpressure. No measured +/// performance benefit from capacity 2 was demonstrated. +pub(crate) const REQUEST_CHANNEL_CAPACITY: usize = 1; + +/// Minimum valid override duration. +const MIN_OVERRIDE: Duration = Duration::from_millis(1); + +// ----------------------------------------------------------------------------- +// ExchangeConfig +// ----------------------------------------------------------------------------- + +/// Configuration for opening a duplex exchange. +#[derive(Debug, Clone)] +pub(crate) struct ExchangeConfig { + /// Per-message timeout for non-full-duplex processing states. + pub message_timeout: Duration, + + /// Upper bound for processor-requested timeout overrides. + pub max_message_timeout: Option, + + /// Body send mode for the request direction. + pub request_body_mode: BodySendMode, + + /// Body send mode for the response direction. + pub response_body_mode: BodySendMode, +} + +// ----------------------------------------------------------------------------- +// ExchangeError +// ----------------------------------------------------------------------------- + +/// Errors during a duplex exchange. +#[derive(Debug, thiserror::Error)] +pub(crate) enum ExchangeError { + /// gRPC transport or protocol error. + #[error("ext_proc gRPC error: {0}")] + Grpc(#[from] tonic::Status), + + /// A processing-state deadline expired. + #[error("ext_proc message timeout")] + Timeout, + + /// The server closed the stream without a required response. + #[error("ext_proc server closed stream without response")] + EmptyStream, + + /// The request channel was closed or sending was finished. + #[error("ext_proc request channel closed")] + SendFailed, + + /// The exchange entered a terminal state. + #[error("ext_proc exchange closed")] + Closed, + + /// A processing deadline could not be represented. + #[error("ext_proc deadline overflow")] + DeadlineOverflow, + + /// A message violated within-direction ordering. + #[error("ext_proc ordering violation: {0}")] + OrderingViolation(String), +} + +// ----------------------------------------------------------------------------- +// ExchangeEvent +// ----------------------------------------------------------------------------- + +/// Typed exchange event classified from a processor response. +/// +/// Each variant preserves the proto response payload and any +/// `dynamic_metadata` from the envelope. +#[derive(Debug)] +pub(crate) enum ExchangeEvent { + /// Request headers response. + RequestHeaders { + /// Processor response payload. + response: HeadersResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Request body response. + RequestBody { + /// Processor response payload. + response: BodyResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Request trailers response. + #[expect( + dead_code, + reason = "classified by receive; consumed in follow-up response lifecycle PR" + )] + RequestTrailers { + /// Processor response payload. + response: TrailersResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Response headers response. + #[expect( + dead_code, + reason = "classified by receive; consumed in follow-up response lifecycle PR" + )] + ResponseHeaders { + /// Processor response payload. + response: HeadersResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Response body response. + #[expect( + dead_code, + reason = "classified by receive; consumed in follow-up response lifecycle PR" + )] + ResponseBody { + /// Processor response payload. + response: BodyResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Response trailers response. + #[expect( + dead_code, + reason = "classified by receive; consumed in follow-up response lifecycle PR" + )] + ResponseTrailers { + /// Processor response payload. + response: TrailersResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, + /// Immediate response — terminal event. + Immediate { + /// Processor immediate response payload. + response: ImmediateResponse, + /// Structured dynamic metadata from the envelope. + metadata: Option, + }, +} + +// ----------------------------------------------------------------------------- +// Phase Types +// ----------------------------------------------------------------------------- + +/// Outbound send phase for a direction. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum SendPhase { + /// No messages sent. + NotStarted, + /// Headers committed. + Headers, + /// Body chunks flowing. + BodyOpen, + /// Terminal body chunk committed (`end_of_stream`). + BodyEos, + /// Trailers committed. + Trailers, +} + +/// Per-direction outbound send state combining phase with +/// body-commitment tracking. +#[derive(Debug, Clone, Copy)] +struct DirectionSendState { + /// Current send phase for the direction. + phase: SendPhase, + /// Whether at least one body message has been committed. + body_ever_committed: bool, +} + +impl DirectionSendState { + /// Create a new send state at the initial phase. + fn new() -> Self { + Self { + phase: SendPhase::NotStarted, + body_ever_committed: false, + } + } +} + +/// Processor-output phase for a direction. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum OutputPhase { + /// No output received. + None, + /// Header response received. + Headers, + /// Body responses flowing. + BodyOpen, + /// Body output completed (EOS received). + BodyDone, + /// Trailer response received. + Trailers, +} + +// ----------------------------------------------------------------------------- +// Active Processing State +// ----------------------------------------------------------------------------- + +/// Which response type is expected from the processor. +/// +/// Each variant corresponds to one of the six directional +/// message types that require a response before the next +/// send is permitted (in non-full-duplex modes). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ExpectedResponse { + /// Awaiting a request headers response. + RequestHeaders, + /// Awaiting a request body response. + RequestBody, + /// Awaiting a request trailers response. + RequestTrailers, + /// Awaiting a response headers response. + ResponseHeaders, + /// Awaiting a response body response. + ResponseBody, + /// Awaiting a response trailers response. + ResponseTrailers, +} + +impl ExpectedResponse { + /// Map an expected response type to its direction. + fn direction(self) -> Direction { + match self { + Self::RequestHeaders | Self::RequestBody | Self::RequestTrailers => Direction::Request, + Self::ResponseHeaders | Self::ResponseBody | Self::ResponseTrailers => Direction::Response, + } + } +} + +/// Per-message processing state for non-full-duplex directions. +/// +/// Tracks the expected response type, deadline, and whether the +/// processor has consumed its one allowed timeout override. +/// Full-duplex directions never create this state. +#[derive(Debug)] +struct ActiveProcessingState { + /// Which response type will consume this state. + expected: ExpectedResponse, + /// Absolute deadline for the processor to respond. + deadline: tokio::time::Instant, + /// Whether the override has been consumed. + override_consumed: bool, +} + +// ----------------------------------------------------------------------------- +// Direction +// ----------------------------------------------------------------------------- + +/// Which direction a message belongs to. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Direction { + /// Request direction. + Request, + /// Response direction. + Response, +} + +// ----------------------------------------------------------------------------- +// SendTransition +// ----------------------------------------------------------------------------- + +/// Proposed state transition computed by [`ExtProcExchange::compute_send_transition`]. +/// +/// Pure value — applying it is the only mutation step. +struct SendTransition { + /// Which direction to advance. + direction: Direction, + /// New send phase for the direction. + new_phase: SendPhase, + /// Optional expected response to install after commit. + /// + /// `Some` for messages that require a response before the next + /// send (all non-full-duplex messages). `None` for all messages + /// in full-duplex directions. + active_state: Option, +} + +// ----------------------------------------------------------------------------- +// Bootstrap State +// ----------------------------------------------------------------------------- + +use std::{future::Future, pin::Pin}; + +/// Pinned boxed Process future, `Send + 'static` but not `Sync`. +type PinnedProcessFuture = + Pin>, tonic::Status>> + Send>>; + +/// Bootstrap state for the Process RPC. +/// +/// Wraps the pending non-`Sync` future in [`SyncWrapper`] so the +/// exchange satisfies `Send + Sync` for typed filter state. +/// Access/poll occurs only through `&mut` exchange methods. +/// +/// [`SyncWrapper`]: sync_wrapper::SyncWrapper +enum BootstrapState { + /// The Process RPC is pending. The wrapped future is polled + /// inline by [`send`] and resolved by [`receive`]. + /// + /// [`send`]: ExtProcExchange::send + /// [`receive`]: ExtProcExchange::receive + Pending(sync_wrapper::SyncWrapper), + + /// The Process RPC completed and the response stream is ready. + Ready(Box>), + + /// The Process RPC failed or the stream was consumed. + Closed, +} + +// ----------------------------------------------------------------------------- +// ExtProcExchange +// ----------------------------------------------------------------------------- + +/// Persistent bidirectional exchange with an external processor. +/// +/// Owns one [`Process`] gRPC stream. [`send`] validates +/// ordering, reserves channel capacity, commits the message, +/// then atomically updates phase and active processing state. +/// [`receive`] reads the next response, handles timeout +/// overrides, validates processor-output ordering, and returns +/// a typed [`ExchangeEvent`]. +/// +/// Timeout policy is derived internally from the active +/// processing state. Callers cannot select or override timeout +/// behavior. +/// +/// No background tasks are spawned. +/// +/// [`Process`]: ExternalProcessorClient::process +/// [`send`]: Self::send +/// [`receive`]: Self::receive +pub(crate) struct ExtProcExchange { + /// Non-full-duplex response expectation with deadline. + active_processing: Option, + + /// Whether the first message has been sent. + first_sent: bool, + + /// Upper bound for processor-requested timeout overrides. + max_message_timeout: Option, + + /// Per-message timeout (non-full-duplex modes). + message_timeout: Duration, + + /// Protocol configuration for the first request. + protocol_config: ProtocolConfiguration, + + /// Request direction body mode. + request_body_mode: BodySendMode, + + /// Processor output phase for the request direction. + request_output: OutputPhase, + + /// Request outbound send state. + request_send: DirectionSendState, + + /// Send half of the bounded request channel. + request_tx: Option>, + + /// Response direction body mode. + response_body_mode: BodySendMode, + + /// Processor output phase for the response direction. + response_output: OutputPhase, + + /// Response outbound send state. + response_send: DirectionSendState, + + /// Process RPC bootstrap and response stream state. + bootstrap: BootstrapState, + + /// Terminal state. + terminal: bool, +} + +impl ExtProcExchange { + /// Open a new exchange on the given channel. + /// + /// Synchronous — constructs the Process future without polling + /// it. The gRPC stream is established when [`send`] or + /// [`receive`] first drives the pending future. + /// + /// [`send`]: Self::send + /// [`receive`]: Self::receive + #[expect(clippy::unnecessary_wraps, reason = "follow-up PR adds preload that can fail")] + pub(crate) fn open(channel: Channel, config: &ExchangeConfig) -> Result { + let (tx, rx) = mpsc::channel(REQUEST_CHANNEL_CAPACITY); + let request_stream = ReceiverStream::new(rx); + let mut client = ExternalProcessorClient::new(channel); + let pending: PinnedProcessFuture = Box::pin(async move { client.process(request_stream).await }); + + Ok(Self { + active_processing: None, + first_sent: false, + max_message_timeout: config.max_message_timeout, + message_timeout: config.message_timeout, + protocol_config: ProtocolConfiguration { + request_body_mode: config.request_body_mode.to_proto_i32(), + response_body_mode: config.response_body_mode.to_proto_i32(), + send_body_without_waiting_for_header_response: false, + }, + request_body_mode: config.request_body_mode, + request_output: OutputPhase::None, + request_send: DirectionSendState::new(), + request_tx: Some(tx), + response_body_mode: config.response_body_mode, + response_output: OutputPhase::None, + response_send: DirectionSendState::new(), + bootstrap: BootstrapState::Pending(sync_wrapper::SyncWrapper::new(pending)), + terminal: false, + }) + } + + /// Send a processing request with transactional state update. + /// + /// 1. Validates the proposed transition (pure, no mutation). + /// 2. Reserves bounded channel capacity (cancellable). + /// 3. Commits the message via `permit.send()`. + /// 4. Atomically updates phase and active processing state (no await between commit steps). + pub(crate) async fn send(&mut self, request: processing_request::Request) -> Result<(), ExchangeError> { + if self.terminal { + return Err(ExchangeError::Closed); + } + + let transition = self.compute_send_transition(&request)?; + + let include_config = !self.first_sent; + let mut msg = ProcessingRequest { + request: Some(request), + ..Default::default() + }; + if include_config { + msg.protocol_config = Some(self.protocol_config); + } + + let timeout = transition.active_state.map(|_| self.message_timeout); + // Clone the sender to avoid aliasing `&mut self` with the permit borrow. + let tx = self.request_tx.clone().ok_or(ExchangeError::SendFailed)?; + let permit = self.reserve_while_bootstrapping(&tx).await?; + + let checked_deadline = timeout + .map(|dur| { + tokio::time::Instant::now() + .checked_add(dur) + .ok_or(ExchangeError::DeadlineOverflow) + }) + .transpose()?; + + permit.send(msg); + + // --- Atomic state commit (no await below) --- + if include_config { + self.first_sent = true; + } + self.apply_send_transition(&transition, checked_deadline); + Ok(()) + } + + /// Read the next response, validate output ordering, and + /// return a typed event. + /// + /// Timeout policy is derived internally: uses the deadline + /// from [`ActiveProcessingState`] if present, or awaits + /// without timeout otherwise. The override loop handles + /// `override_message_timeout` envelopes before classification. + /// + /// Takes no arguments — timeout behavior is fully internal. + pub(crate) async fn receive(&mut self) -> Result { + if self.terminal { + return Err(ExchangeError::Closed); + } + + let result = self.receive_inner().await; + match result { + Ok(event) => { + if matches!(event, ExchangeEvent::Immediate { .. }) { + self.terminal = true; + } + Ok(event) + }, + Err(e) => { + self.terminal = true; + Err(e) + }, + } + } + + /// Half-close the request stream. Direction-local, not + /// terminal. Response events remain readable. + pub(crate) fn finish_sending(&mut self) { + self.request_tx.take(); + } + + /// Consume remaining response stream messages to allow clean + /// h2 stream closure. Prevents `RST_STREAM` on exchange drop + /// when the server has trailing data. + /// + /// Only drains when the bootstrap is [`Ready`]. If the + /// bootstrap is still [`Pending`], this is a no-op — callers + /// must ensure at least one successful [`receive`] before + /// calling `drain_trailing` for clean closure. + /// + /// [`Ready`]: BootstrapState::Ready + /// [`Pending`]: BootstrapState::Pending + /// [`receive`]: Self::receive + pub(crate) async fn drain_trailing(&mut self) { + if let BootstrapState::Ready(ref mut stream) = self.bootstrap { + while stream.message().await.is_ok_and(|m| m.is_some()) {} + } + } + + /// Whether the exchange has entered a terminal state. + pub(crate) fn is_terminal(&self) -> bool { + self.terminal + } + + /// Whether the outbound request channel has been closed. + /// + /// Outbound closure is direction-local: the bootstrap/response + /// stream may still contain buffered responses. + #[expect(dead_code, reason = "used by integration tests and follow-up PRs")] + pub(crate) fn is_outbound_closed(&self) -> bool { + self.request_tx.is_none() + } + + /// Reserve bounded channel capacity while driving the pending + /// Process future via [`tokio::select!`]. + /// + /// When the bootstrap is [`Pending`], polls both the channel + /// reserve and the Process future. If the Process future + /// resolves first, stores the response stream as [`Ready`] + /// and continues the same reserve attempt. If the channel + /// reserve wins, the pending future is preserved. + /// + /// [`Pending`]: BootstrapState::Pending + /// [`Ready`]: BootstrapState::Ready + #[expect(clippy::too_many_lines, reason = "select! branches with state transitions")] + async fn reserve_while_bootstrapping<'a>( + &mut self, + tx: &'a mpsc::Sender, + ) -> Result, ExchangeError> { + loop { + match self.bootstrap { + BootstrapState::Pending(ref mut wrapper) => { + let future = wrapper.get_mut(); + tokio::select! { + biased; + permit = tx.reserve() => { + return permit.map_err(|_send_err| { + self.request_tx.take(); + ExchangeError::SendFailed + }); + }, + result = future.as_mut() => { + match result { + Ok(response) => { + self.bootstrap = BootstrapState::Ready(Box::new(response.into_inner())); + }, + Err(status) => { + self.bootstrap = BootstrapState::Closed; + self.request_tx.take(); + self.terminal = true; + return Err(ExchangeError::Grpc(status)); + }, + } + }, + } + }, + BootstrapState::Ready(_) | BootstrapState::Closed => { + return tx.reserve().await.map_err(|_send_err| { + self.request_tx.take(); + ExchangeError::SendFailed + }); + }, + } + } + } + + /// Snapshot of output phases for transactional-validation + /// testing. + #[cfg(test)] + pub(crate) fn output_phases(&self) -> (OutputPhase, OutputPhase) { + (self.request_output, self.response_output) + } +} + +// ----------------------------------------------------------------------------- +// Send Transition — Computation +// ----------------------------------------------------------------------------- + +#[expect( + clippy::multiple_inherent_impl, + reason = "sectioned state-machine implementation keeps domains reviewable" +)] +impl ExtProcExchange { + /// Compute the proposed send transition without mutating state. + /// + /// Validates ordering, body-mode gating, and active processing + /// exclusivity. Returns a pure [`SendTransition`] value. + /// + /// Full-duplex directions (`FULL_DUPLEX_STREAMED`) never create + /// active processing state for any message type — headers, body, + /// or trailers — because full-duplex processing has no + /// per-message timeout. + #[expect( + clippy::too_many_lines, + reason = "six direction×type variants with mode-aware active-state logic" + )] + fn compute_send_transition(&self, request: &processing_request::Request) -> Result { + let transition = match request { + processing_request::Request::RequestHeaders(_) => { + require_phase( + self.request_send.phase, + SendPhase::NotStarted, + "request headers already sent", + )?; + let creates_active = !self.request_body_mode.is_full_duplex(); + if creates_active { + self.require_no_active_processing("request headers")?; + } + SendTransition { + direction: Direction::Request, + new_phase: SendPhase::Headers, + active_state: creates_active.then_some(ExpectedResponse::RequestHeaders), + } + }, + processing_request::Request::RequestBody(b) => { + self.require_body_mode_enabled(Direction::Request)?; + require_body_phase(self.request_send.phase, "request body")?; + let full_duplex = self.request_body_mode.is_full_duplex(); + if !full_duplex { + self.require_no_active_processing("request body")?; + } + SendTransition { + direction: Direction::Request, + new_phase: if b.end_of_stream { + SendPhase::BodyEos + } else { + SendPhase::BodyOpen + }, + active_state: if full_duplex { + None + } else { + Some(ExpectedResponse::RequestBody) + }, + } + }, + processing_request::Request::RequestTrailers(_) => { + require_trailer_phase(self.request_send.phase, "request trailers")?; + let creates_active = !self.request_body_mode.is_full_duplex(); + if creates_active { + self.require_no_active_processing("request trailers")?; + } + SendTransition { + direction: Direction::Request, + new_phase: SendPhase::Trailers, + active_state: creates_active.then_some(ExpectedResponse::RequestTrailers), + } + }, + processing_request::Request::ResponseHeaders(_) => { + require_phase( + self.response_send.phase, + SendPhase::NotStarted, + "response headers already sent", + )?; + let creates_active = !self.response_body_mode.is_full_duplex(); + if creates_active { + self.require_no_active_processing("response headers")?; + } + SendTransition { + direction: Direction::Response, + new_phase: SendPhase::Headers, + active_state: creates_active.then_some(ExpectedResponse::ResponseHeaders), + } + }, + processing_request::Request::ResponseBody(b) => { + self.require_body_mode_enabled(Direction::Response)?; + require_body_phase(self.response_send.phase, "response body")?; + let full_duplex = self.response_body_mode.is_full_duplex(); + if !full_duplex { + self.require_no_active_processing("response body")?; + } + SendTransition { + direction: Direction::Response, + new_phase: if b.end_of_stream { + SendPhase::BodyEos + } else { + SendPhase::BodyOpen + }, + active_state: if full_duplex { + None + } else { + Some(ExpectedResponse::ResponseBody) + }, + } + }, + processing_request::Request::ResponseTrailers(_) => { + require_trailer_phase(self.response_send.phase, "response trailers")?; + let creates_active = !self.response_body_mode.is_full_duplex(); + if creates_active { + self.require_no_active_processing("response trailers")?; + } + SendTransition { + direction: Direction::Response, + new_phase: SendPhase::Trailers, + active_state: creates_active.then_some(ExpectedResponse::ResponseTrailers), + } + }, + }; + + Ok(transition) + } + + /// Reject sends when active processing state is outstanding. + fn require_no_active_processing(&self, label: &str) -> Result<(), ExchangeError> { + if self.active_processing.is_some() { + return Err(ExchangeError::OrderingViolation(format!( + "cannot send {label}: active processing state outstanding" + ))); + } + Ok(()) + } + + /// Reject body sends when the body mode is `None`. + fn require_body_mode_enabled(&self, direction: Direction) -> Result<(), ExchangeError> { + let mode = match direction { + Direction::Request => self.request_body_mode, + Direction::Response => self.response_body_mode, + }; + if matches!(mode, BodySendMode::None) { + let dir = match direction { + Direction::Request => "request", + Direction::Response => "response", + }; + return Err(ExchangeError::OrderingViolation(format!( + "{dir} body send rejected: body mode is none" + ))); + } + Ok(()) + } +} + +// ----------------------------------------------------------------------------- +// Send Transition — Application +// ----------------------------------------------------------------------------- + +#[expect( + clippy::multiple_inherent_impl, + reason = "sectioned state-machine implementation keeps domains reviewable" +)] +impl ExtProcExchange { + /// Apply the committed transition atomically (no await). + /// + /// The deadline is created after `reserve().await` and before + /// `permit.send()`, so producer backpressure is excluded from + /// the processing deadline. + fn apply_send_transition(&mut self, t: &SendTransition, checked_deadline: Option) { + let state = match t.direction { + Direction::Request => &mut self.request_send, + Direction::Response => &mut self.response_send, + }; + let is_body = matches!(t.new_phase, SendPhase::BodyOpen | SendPhase::BodyEos); + state.phase = t.new_phase; + if is_body { + state.body_ever_committed = true; + } + if let (Some(expected), Some(deadline)) = (t.active_state, checked_deadline) { + self.active_processing = Some(ActiveProcessingState { + expected, + deadline, + override_consumed: false, + }); + } + } +} + +// ----------------------------------------------------------------------------- +// Send Phase Validation Helpers +// ----------------------------------------------------------------------------- + +/// Require the current phase to be `expected`. +fn require_phase(current: SendPhase, expected: SendPhase, msg: &str) -> Result<(), ExchangeError> { + if current != expected { + return Err(ExchangeError::OrderingViolation(msg.to_owned())); + } + Ok(()) +} + +/// Require the current phase to accept a body message. +fn require_body_phase(current: SendPhase, label: &str) -> Result<(), ExchangeError> { + if !matches!(current, SendPhase::Headers | SendPhase::BodyOpen) { + return Err(ExchangeError::OrderingViolation(format!( + "{label} requires headers sent and no EOS/trailers" + ))); + } + Ok(()) +} + +/// Require the current phase to accept trailers. +fn require_trailer_phase(current: SendPhase, label: &str) -> Result<(), ExchangeError> { + if !matches!(current, SendPhase::Headers | SendPhase::BodyOpen) { + return Err(ExchangeError::OrderingViolation(format!( + "{label} requires headers sent and no EOS" + ))); + } + Ok(()) +} + +// ----------------------------------------------------------------------------- +// Receive Implementation +// ----------------------------------------------------------------------------- + +#[expect( + clippy::multiple_inherent_impl, + reason = "sectioned state-machine implementation keeps domains reviewable" +)] +impl ExtProcExchange { + /// Ensure the bootstrap has resolved to a ready response + /// stream. Awaits the pending Process future if necessary. + async fn ensure_response_stream(&mut self) -> Result<(), ExchangeError> { + if let BootstrapState::Pending(ref mut wrapper) = self.bootstrap { + let future = wrapper.get_mut(); + let response = future.await.map_err(|status| { + self.bootstrap = BootstrapState::Closed; + self.terminal = true; + ExchangeError::Grpc(status) + })?; + self.bootstrap = BootstrapState::Ready(Box::new(response.into_inner())); + } + Ok(()) + } + + /// Internal receive with deferred stream resolution, override + /// loop, and classification. + /// + /// 1. Resolves the deferred response stream if needed. + /// 2. Reads a response with optional deadline from active processing. + /// 3. Runs the override loop: if `override_message_timeout` is present on the envelope, it is an override envelope. + /// Valid overrides replace the deadline and continue reading. Invalid overrides are silently ignored (the entire + /// envelope is discarded, including any populated response oneof). Override envelopes never reach + /// [`classify_and_validate`]. + /// 4. Classifies the response against expected type and validates output ordering. + /// + /// [`classify_and_validate`]: Self::classify_and_validate + async fn receive_inner(&mut self) -> Result { + self.ensure_response_stream().await?; + loop { + let deadline = self.active_processing.as_ref().map(|ap| ap.deadline); + let stream = match self.bootstrap { + BootstrapState::Ready(ref mut s) => s, + BootstrapState::Closed => return Err(ExchangeError::Closed), + BootstrapState::Pending(_) => { + return Err(ExchangeError::OrderingViolation( + "bootstrap still pending after ensure".to_owned(), + )); + }, + }; + let resp = read_with_optional_deadline(stream, deadline).await?; + + // If override_message_timeout is present, this is an + // override envelope — it never reaches classification. + if resp.override_message_timeout.is_some() { + // Try to apply it; whether accepted or rejected, + // the envelope is consumed and we read the next one. + self.try_accept_override(&resp); + continue; + } + + return self.classify_and_validate(resp); + } + } + + /// Try to accept an override from a response envelope. + /// + /// Returns `true` if the override was accepted and the active + /// deadline was replaced. Returns `false` if the override is + /// invalid or not applicable. In both cases, the caller + /// consumes the entire envelope and continues reading — + /// invalid override envelopes are never classified as + /// ordinary responses. + /// + /// Conditions for acceptance: + /// - `override_message_timeout` is present + /// - Active processing state exists (a timer is running) + /// - Override has not already been consumed for this state + /// - `max_message_timeout` is configured + /// - Duration passes strict protobuf validation + /// - Duration is at least 1ms + #[expect( + clippy::too_many_lines, + reason = "override validation with multiple early-return conditions" + )] + fn try_accept_override(&mut self, resp: &ProcessingResponse) -> bool { + let Some(proto_dur) = resp.override_message_timeout.as_ref() else { + return false; + }; + + if self.active_processing.as_ref().is_none_or(|ap| ap.override_consumed) { + return false; + } + + let Some(max) = self.max_message_timeout else { + return false; + }; + + let Some(dur) = parse_override_duration(proto_dur) else { + return false; + }; + + let clamped = dur.min(max); + if clamped < dur { + tracing::warn!( + requested_ms = dur.as_millis(), + clamped_ms = clamped.as_millis(), + "ext_proc exchange: override clamped to max" + ); + } + + tracing::debug!( + override_ms = clamped.as_millis(), + "ext_proc exchange: timeout override accepted" + ); + + // Compute the new deadline with overflow protection. + let Some(deadline) = tokio::time::Instant::now().checked_add(clamped) else { + return false; + }; + + // Mutate active processing state. The `is_none_or` + // guard above ensures this branch is taken. + if let Some(ap) = self.active_processing.as_mut() { + ap.deadline = deadline; + ap.override_consumed = true; + } + true + } + + /// Classify a raw response and validate output ordering. + /// + /// Validation is transactional: output phase is advanced on a + /// local copy first, then committed only after all checks pass. + /// This prevents rejected responses from corrupting state. + /// + /// [`ImmediateResponse`] is terminal regardless of active state + /// but requires at least one outbound message to have been sent. + #[expect( + clippy::too_many_lines, + reason = "seven response variants with transactional output validation" + )] + fn classify_and_validate(&mut self, resp: ProcessingResponse) -> Result { + let metadata = resp.dynamic_metadata; + + let Some(response) = resp.response else { + return Err(ExchangeError::OrderingViolation( + "empty response with no override".to_owned(), + )); + }; + + match response { + processing_response::Response::ImmediateResponse(r) => { + if !self.first_sent { + return Err(ExchangeError::OrderingViolation( + "immediate response before first send".to_owned(), + )); + } + self.active_processing = None; + Ok(ExchangeEvent::Immediate { response: r, metadata }) + }, + processing_response::Response::RequestHeaders(r) => { + let expected = ExpectedResponse::RequestHeaders; + self.validate_response_solicited(expected)?; + let mut local_output = self.request_output; + validate_output_transition(&mut local_output, OutputPhase::Headers, "request headers")?; + self.request_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::RequestHeaders { response: r, metadata }) + }, + processing_response::Response::RequestBody(r) => { + let expected = ExpectedResponse::RequestBody; + self.validate_response_solicited(expected)?; + validate_body_mutation_mode(&r, self.request_body_mode, "request body")?; + let mut local_output = self.request_output; + validate_body_output(&mut local_output, &r, "request body")?; + self.request_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::RequestBody { response: r, metadata }) + }, + processing_response::Response::RequestTrailers(r) => { + let expected = ExpectedResponse::RequestTrailers; + self.validate_response_solicited(expected)?; + let mut local_output = self.request_output; + validate_output_transition(&mut local_output, OutputPhase::Trailers, "request trailers")?; + self.request_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::RequestTrailers { response: r, metadata }) + }, + processing_response::Response::ResponseHeaders(r) => { + let expected = ExpectedResponse::ResponseHeaders; + self.validate_response_solicited(expected)?; + let mut local_output = self.response_output; + validate_output_transition(&mut local_output, OutputPhase::Headers, "response headers")?; + self.response_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::ResponseHeaders { response: r, metadata }) + }, + processing_response::Response::ResponseBody(r) => { + let expected = ExpectedResponse::ResponseBody; + self.validate_response_solicited(expected)?; + validate_body_mutation_mode(&r, self.response_body_mode, "response body")?; + let mut local_output = self.response_output; + validate_body_output(&mut local_output, &r, "response body")?; + self.response_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::ResponseBody { response: r, metadata }) + }, + processing_response::Response::ResponseTrailers(r) => { + let expected = ExpectedResponse::ResponseTrailers; + self.validate_response_solicited(expected)?; + let mut local_output = self.response_output; + validate_output_transition(&mut local_output, OutputPhase::Trailers, "response trailers")?; + self.response_output = local_output; + self.consume_active_if_matched(expected); + Ok(ExchangeEvent::ResponseTrailers { response: r, metadata }) + }, + } + } + + /// Get the body send mode for the given direction. + fn body_mode(&self, direction: Direction) -> BodySendMode { + match direction { + Direction::Request => self.request_body_mode, + Direction::Response => self.response_body_mode, + } + } + + /// Get the outbound send state for the given direction. + fn send_state(&self, direction: Direction) -> DirectionSendState { + match direction { + Direction::Request => self.request_send, + Direction::Response => self.response_send, + } + } + + /// Whether an outbound message matching `received` has been + /// committed in the appropriate direction. + fn committed_for(&self, received: ExpectedResponse) -> bool { + let state = self.send_state(received.direction()); + match received { + ExpectedResponse::RequestHeaders | ExpectedResponse::ResponseHeaders => { + state.phase != SendPhase::NotStarted + }, + ExpectedResponse::RequestBody | ExpectedResponse::ResponseBody => state.body_ever_committed, + ExpectedResponse::RequestTrailers | ExpectedResponse::ResponseTrailers => { + state.phase == SendPhase::Trailers + }, + } + } + + /// Validate that the server's response was solicited by a + /// matching outbound commit. + /// + /// In full-duplex mode, checks that an outbound message of the + /// corresponding type was committed at some point. In + /// non-full-duplex mode, checks against the active processing + /// state. + fn validate_response_solicited(&self, received: ExpectedResponse) -> Result<(), ExchangeError> { + let direction = received.direction(); + let mode = self.body_mode(direction); + + if mode.is_full_duplex() { + if !self.committed_for(received) { + return Err(ExchangeError::OrderingViolation(format!( + "unsolicited {received:?} (no matching outbound committed)" + ))); + } + return Ok(()); + } + + match &self.active_processing { + Some(active) if active.expected == received => Ok(()), + Some(active) => Err(ExchangeError::OrderingViolation(format!( + "expected {:?}, received {received:?}", + active.expected + ))), + None => Err(ExchangeError::OrderingViolation(format!( + "unsolicited {received:?} (no active processing state)" + ))), + } + } + + /// Consume active processing state if the response matches + /// the expected type. Only applicable in non-full-duplex mode. + fn consume_active_if_matched(&mut self, received: ExpectedResponse) { + if self + .active_processing + .as_ref() + .is_some_and(|ap| ap.expected == received) + { + self.active_processing = None; + } + } +} + +// ----------------------------------------------------------------------------- +// Duration Parsing +// ----------------------------------------------------------------------------- + +/// Parse a protobuf [`Duration`] into a [`std::time::Duration`] +/// with strict validation. +/// +/// Returns `None` if the value is negative, out of protobuf range, +/// has invalid nanos, or is below the minimum override threshold. +/// +/// [`Duration`]: prost_types::Duration +fn parse_override_duration(value: &prost_types::Duration) -> Option { + // Protobuf Duration: seconds in [-315_576_000_000, 315_576_000_000], + // nanos in [-999_999_999, 999_999_999], same sign as seconds. + if value.seconds < 0 || value.seconds > 315_576_000_000 { + return None; + } + if value.nanos < 0 || value.nanos >= 1_000_000_000 { + return None; + } + #[expect(clippy::cast_sign_loss, reason = "negative values rejected above")] + let dur = Duration::new(value.seconds as u64, value.nanos as u32); + if dur < MIN_OVERRIDE { + return None; + } + Some(dur) +} + +// ----------------------------------------------------------------------------- +// Output Validation +// ----------------------------------------------------------------------------- + +/// Validate a non-body output phase transition. +fn validate_output_transition( + output: &mut OutputPhase, + expected: OutputPhase, + label: &str, +) -> Result<(), ExchangeError> { + let valid = match (output, expected) { + (phase @ &mut OutputPhase::None, OutputPhase::Headers) => { + *phase = OutputPhase::Headers; + true + }, + (phase @ &mut (OutputPhase::Headers | OutputPhase::BodyOpen), OutputPhase::Trailers) => { + *phase = OutputPhase::Trailers; + true + }, + _ => false, + }; + if !valid { + return Err(ExchangeError::OrderingViolation(format!( + "unexpected {label} in output phase" + ))); + } + Ok(()) +} + +/// Validate and advance body output phase, including EOS tracking. +/// +/// Checks `StreamedBodyResponse.end_of_stream` to transition +/// to [`OutputPhase::BodyDone`]. Post-EOS and duplicate-EOS +/// body outputs are rejected. +fn validate_body_output(output: &mut OutputPhase, body_resp: &BodyResponse, label: &str) -> Result<(), ExchangeError> { + if matches!(output, OutputPhase::BodyDone | OutputPhase::Trailers) { + return Err(ExchangeError::OrderingViolation(format!("post-EOS {label} output"))); + } + if !matches!(output, OutputPhase::Headers | OutputPhase::BodyOpen) { + return Err(ExchangeError::OrderingViolation(format!( + "{label} output before headers" + ))); + } + + let is_eos = body_resp + .response + .as_ref() + .and_then(|c| c.body_mutation.as_ref()) + .and_then(|bm| match &bm.mutation { + Some(praxis_proto::envoy::service::ext_proc::v3::body_mutation::Mutation::StreamedResponse(sr)) => { + Some(sr.end_of_stream) + }, + _ => None, + }) + .unwrap_or(false); + + *output = if is_eos { + OutputPhase::BodyDone + } else { + OutputPhase::BodyOpen + }; + Ok(()) +} + +/// Validate that the body response mutation type matches the +/// direction's body send mode. +/// +/// - [`BodySendMode::FullDuplexStreamed`] requires a [`StreamedBodyResponse`] mutation. +/// - All other modes reject [`StreamedBodyResponse`] mutations. +/// +/// [`StreamedBodyResponse`]: praxis_proto::envoy::service::ext_proc::v3::StreamedBodyResponse +fn validate_body_mutation_mode( + body_resp: &BodyResponse, + body_mode: BodySendMode, + label: &str, +) -> Result<(), ExchangeError> { + let has_streamed = body_resp + .response + .as_ref() + .and_then(|c| c.body_mutation.as_ref()) + .is_some_and(|bm| { + matches!( + bm.mutation, + Some(praxis_proto::envoy::service::ext_proc::v3::body_mutation::Mutation::StreamedResponse(_)) + ) + }); + + if body_mode.is_full_duplex() && !has_streamed { + return Err(ExchangeError::OrderingViolation(format!( + "{label}: full-duplex mode requires StreamedBodyResponse mutation" + ))); + } + if !body_mode.is_full_duplex() && has_streamed { + return Err(ExchangeError::OrderingViolation(format!( + "{label}: StreamedBodyResponse mutation requires full-duplex mode" + ))); + } + Ok(()) +} + +// ----------------------------------------------------------------------------- +// I/O Utilities +// ----------------------------------------------------------------------------- + +/// Read the next message with an optional absolute deadline. +async fn read_with_optional_deadline( + streaming: &mut tonic::Streaming, + deadline: Option, +) -> Result { + if let Some(dl) = deadline { + tokio::time::timeout_at(dl, next_message(streaming)) + .await + .map_err(|_elapsed| ExchangeError::Timeout)? + } else { + next_message(streaming).await + } +} + +/// Reserve capacity, compute checked deadline, and commit message. +/// +/// Does not mutate exchange lifecycle state. The caller commits +/// `first_sent`, outbound phase, and active processing state +/// after this returns. +/// `timeout` is `Some(duration)` when the committed message +/// requires a processing deadline, `None` otherwise. +#[cfg(test)] +pub(crate) async fn commit_message( + tx: &mpsc::Sender, + msg: ProcessingRequest, + timeout: Option, +) -> Result, ExchangeError> { + let permit = tx.reserve().await.map_err(|_send_err| ExchangeError::SendFailed)?; + + let deadline = if let Some(dur) = timeout { + Some( + tokio::time::Instant::now() + .checked_add(dur) + .ok_or(ExchangeError::DeadlineOverflow)?, + ) + } else { + None + }; + + permit.send(msg); + Ok(deadline) +} + +/// Read the next message from the response stream. +async fn next_message( + streaming: &mut tonic::Streaming, +) -> Result { + streaming + .message() + .await + .map_err(ExchangeError::Grpc)? + .ok_or(ExchangeError::EmptyStream) +} diff --git a/filter/ext-proc/src/lib.rs b/filter/ext-proc/src/lib.rs index 9d84062a..4016542b 100644 --- a/filter/ext-proc/src/lib.rs +++ b/filter/ext-proc/src/lib.rs @@ -36,6 +36,8 @@ #![deny(unreachable_pub)] mod callout; +#[expect(dead_code, reason = "wired into ExtProcFilter in follow-up PR")] +pub(crate) mod duplex; mod mutations; use std::time::Duration; @@ -340,6 +342,27 @@ impl std::fmt::Display for BodySendMode { } } +impl BodySendMode { + /// Whether this mode uses full-duplex streaming. + pub(crate) fn is_full_duplex(self) -> bool { + self == Self::FullDuplexStreamed + } + + /// Convert to the protobuf [`BodySendMode`] enum integer value. + /// + /// [`BodySendMode`]: praxis_proto::envoy::service::ext_proc::v3::BodySendMode + pub(crate) fn to_proto_i32(self) -> i32 { + use praxis_proto::envoy::service::ext_proc::v3::BodySendMode as ProtoMode; + match self { + Self::None => ProtoMode::None as i32, + Self::Streamed => ProtoMode::Streamed as i32, + Self::Buffered => ProtoMode::Buffered as i32, + Self::BufferedPartial => ProtoMode::BufferedPartial as i32, + Self::FullDuplexStreamed => ProtoMode::FullDuplexStreamed as i32, + } + } +} + // ----------------------------------------------------------------------------- // MutationRulesConfig / ForwardRulesConfig // ----------------------------------------------------------------------------- diff --git a/filter/ext-proc/src/tests.rs b/filter/ext-proc/src/tests.rs index 46522ea4..2c1da15f 100644 --- a/filter/ext-proc/src/tests.rs +++ b/filter/ext-proc/src/tests.rs @@ -1,12 +1,17 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2026 Praxis Contributors -#![allow( +#![expect( + clippy::items_after_statements, clippy::let_underscore_must_use, clippy::unwrap_used, clippy::expect_used, clippy::indexing_slicing, clippy::panic, + clippy::clone_on_ref_ptr, + clippy::doc_markdown, + clippy::significant_drop_tightening, + clippy::too_many_lines, reason = "tests" )] @@ -24,6 +29,7 @@ use praxis_proto::envoy::service::{ }; use super::*; +use crate::duplex::{ExchangeConfig, ExchangeError, ExchangeEvent, ExtProcExchange}; // ----------------------------------------------------------------------------- // Config Parsing @@ -2129,7 +2135,7 @@ use std::{net::SocketAddr, pin::Pin}; use async_trait::async_trait; use praxis_proto::envoy::service::ext_proc::v3::{ - BodyResponse, ProcessingRequest, ProcessingResponse, + BodyResponse, ProcessingRequest, ProcessingResponse, ProtocolConfiguration, TrailersResponse, external_processor_server::{ExternalProcessor, ExternalProcessorServer}, processing_request, processing_response, }; @@ -2237,10 +2243,22 @@ fn build_noop_response(req: &ProcessingRequest) -> ProcessingResponse { Some(processing_request::Request::RequestHeaders(_)) => { processing_response::Response::RequestHeaders(HeadersResponse { response: None }) }, + Some(processing_request::Request::RequestBody(_)) => { + processing_response::Response::RequestBody(BodyResponse { response: None }) + }, + Some(processing_request::Request::RequestTrailers(_)) => { + processing_response::Response::RequestTrailers(TrailersResponse { header_mutation: None }) + }, Some(processing_request::Request::ResponseHeaders(_)) => { processing_response::Response::ResponseHeaders(HeadersResponse { response: None }) }, - _ => processing_response::Response::RequestHeaders(HeadersResponse { response: None }), + Some(processing_request::Request::ResponseBody(_)) => { + processing_response::Response::ResponseBody(BodyResponse { response: None }) + }, + Some(processing_request::Request::ResponseTrailers(_)) => { + processing_response::Response::ResponseTrailers(TrailersResponse { header_mutation: None }) + }, + None => processing_response::Response::RequestHeaders(HeadersResponse { response: None }), }; ProcessingResponse { response: Some(response), @@ -2371,3 +2389,3940 @@ async fn wait_for_server(addr: SocketAddr) { } panic!("mock server at {addr} did not become ready"); } + +// ============================================================================= +// Duplex Exchange Tests +// ============================================================================= + +fn default_exchange_config() -> ExchangeConfig { + ExchangeConfig { + message_timeout: Duration::from_secs(5), + max_message_timeout: None, + request_body_mode: BodySendMode::None, + response_body_mode: BodySendMode::None, + } +} + +fn streamed_body_exchange_config() -> ExchangeConfig { + ExchangeConfig { + request_body_mode: BodySendMode::Streamed, + response_body_mode: BodySendMode::Streamed, + ..default_exchange_config() + } +} + +fn full_duplex_exchange_config() -> ExchangeConfig { + ExchangeConfig { + request_body_mode: BodySendMode::FullDuplexStreamed, + response_body_mode: BodySendMode::FullDuplexStreamed, + ..default_exchange_config() + } +} + +// ----------------------------------------------------------------------------- +// Duplex Mock Server +// ----------------------------------------------------------------------------- + +/// Configurable behavior for the duplex mock processor. +/// +/// Unlike [`MockBehavior`] which handles one message, +/// this reads the full conversation. +#[derive(Clone)] +enum DuplexBehavior { + /// Read request headers, respond with header mutation. + EchoHeaders { name: String, value: String }, + + /// Read headers + body chunks. Respond only after body EOS. + /// Returns header response then body response. + DelayedRouting { header_name: String, header_value: String }, + + /// Respond with ImmediateResponse on request headers. + ImmediateOnHeaders { status: i32, body: String }, + + /// Read headers, then respond with ImmediateResponse on first body chunk. + ImmediateOnBody { status: i32, body: String }, + + /// Read headers + body EOS, respond with multiple StreamedBodyResponse chunks. + StreamedBodyChunks { chunks: Vec> }, + + /// Handle full lifecycle: request headers, request body, response headers. + FullLifecycle { + req_header_name: String, + req_header_value: String, + resp_header_name: String, + resp_header_value: String, + }, + + /// Never respond (timeout testing). + Hang, + + /// Close stream immediately without responding. + CloseEarly, + + /// Send override_message_timeout then delayed header response. + OverrideTimeout { + override_ms: u64, + delay_ms: u64, + name: String, + value: String, + }, + + /// Echo headers, respond with unexpected body response type. + UnexpectedResponseType, + + /// Read headers + body, respond to both. Body response uses + /// simple BodyMutation (not streamed). + HeadersAndBody, +} + +struct DuplexMockProcessor { + behavior: DuplexBehavior, +} + +#[async_trait] +impl ExternalProcessor for DuplexMockProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let behavior = self.behavior.clone(); + + let (tx, rx) = tokio::sync::mpsc::channel(16); + tokio::spawn(async move { + match behavior { + DuplexBehavior::EchoHeaders { name, value } => { + let msg = stream.message().await.unwrap().unwrap(); + let resp = build_add_header_response(&msg, &name, &value); + drop(tx.send(Ok(resp)).await); + }, + DuplexBehavior::DelayedRouting { + header_name, + header_value, + } => { + let header_msg = stream.message().await.unwrap().unwrap(); + loop { + let body_msg = stream.message().await.unwrap().unwrap(); + if let Some(processing_request::Request::RequestBody(b)) = &body_msg.request + && b.end_of_stream + { + break; + } + } + let header_resp = build_add_header_response(&header_msg, &header_name, &header_value); + drop(tx.send(Ok(header_resp)).await); + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: Vec::new(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }, + DuplexBehavior::ImmediateOnHeaders { status, body } => { + let _msg = stream.message().await.unwrap().unwrap(); + let resp = build_immediate_response(status, &body); + drop(tx.send(Ok(resp)).await); + }, + DuplexBehavior::ImmediateOnBody { status, body } => { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + let _body_msg = stream.message().await.unwrap().unwrap(); + let resp = build_immediate_response(status, &body); + drop(tx.send(Ok(resp)).await); + }, + DuplexBehavior::StreamedBodyChunks { chunks } => { + let _headers = stream.message().await.unwrap().unwrap(); + loop { + let body_msg = stream.message().await.unwrap().unwrap(); + if let Some(processing_request::Request::RequestBody(b)) = &body_msg.request + && b.end_of_stream + { + break; + } + } + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + for (i, chunk) in chunks.iter().enumerate() { + let is_last = i == chunks.len() - 1; + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse( + StreamedBodyResponse { + body: chunk.clone(), + end_of_stream: is_last, + }, + )), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + } + }, + DuplexBehavior::FullLifecycle { + req_header_name, + req_header_value, + resp_header_name, + resp_header_value, + } => { + let header_msg = stream.message().await.unwrap().unwrap(); + let req_resp = build_add_header_response(&header_msg, &req_header_name, &req_header_value); + drop(tx.send(Ok(req_resp)).await); + + while let Ok(Some(msg)) = stream.message().await { + if let Some(processing_request::Request::ResponseHeaders(_)) = msg.request { + let resp_resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: Some(CommonResponse { + header_mutation: Some(HeaderMutation { + set_headers: vec![make_hvo(&resp_header_name, &resp_header_value)], + remove_headers: vec![], + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(resp_resp)).await); + break; + } + } + }, + DuplexBehavior::Hang => { + futures::future::pending::<()>().await; + }, + DuplexBehavior::CloseEarly => {}, + DuplexBehavior::OverrideTimeout { + override_ms, + delay_ms, + name, + value, + } => { + let msg = stream.message().await.unwrap().unwrap(); + let override_resp = build_override_response(override_ms); + drop(tx.send(Ok(override_resp)).await); + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + let real_resp = build_add_header_response(&msg, &name, &value); + drop(tx.send(Ok(real_resp)).await); + }, + DuplexBehavior::UnexpectedResponseType => { + let _msg = stream.message().await.unwrap().unwrap(); + let resp = build_unexpected_body_response(); + drop(tx.send(Ok(resp)).await); + }, + DuplexBehavior::HeadersAndBody => { + let header_msg = stream.message().await.unwrap().unwrap(); + let header_resp = build_noop_response(&header_msg); + drop(tx.send(Ok(header_resp)).await); + + let _body_msg = stream.message().await.unwrap().unwrap(); + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }, + } + }); + + let output = tokio_stream::wrappers::ReceiverStream::new(rx); + Ok(tonic::Response::new(Box::pin(output))) + } +} + +async fn start_duplex_processor(behavior: DuplexBehavior) -> (SocketAddr, MockServerGuard) { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let svc = ExternalProcessorServer::new(DuplexMockProcessor { behavior }); + + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + + wait_for_server(addr).await; + + let guard = MockServerGuard { + shutdown: Some(shutdown_tx), + }; + (addr, guard) +} + +fn make_request_headers() -> processing_request::Request { + processing_request::Request::RequestHeaders(praxis_proto::envoy::service::ext_proc::v3::HttpHeaders { + headers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { + headers: vec![HeaderValue { + key: ":method".to_owned(), + value: "GET".to_owned(), + raw_value: Vec::new(), + }], + }), + end_of_stream: false, + }) +} + +fn make_request_body(body: &[u8], end_of_stream: bool) -> processing_request::Request { + processing_request::Request::RequestBody(praxis_proto::envoy::service::ext_proc::v3::HttpBody { + body: body.to_vec(), + end_of_stream, + }) +} + +fn make_response_headers() -> processing_request::Request { + processing_request::Request::ResponseHeaders(praxis_proto::envoy::service::ext_proc::v3::HttpHeaders { + headers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { + headers: vec![HeaderValue { + key: ":status".to_owned(), + value: "200".to_owned(), + raw_value: Vec::new(), + }], + }), + end_of_stream: false, + }) +} + +fn make_request_trailers() -> processing_request::Request { + processing_request::Request::RequestTrailers(praxis_proto::envoy::service::ext_proc::v3::HttpTrailers { + trailers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { headers: vec![] }), + }) +} + +// ----------------------------------------------------------------------------- +// Duplex Exchange Test Functions +// ----------------------------------------------------------------------------- + +/// Mock that records the protocol_config from the first message. +struct ProtocolConfigRecorder { + recorded: std::sync::Arc>>, +} + +#[async_trait] +impl ExternalProcessor for ProtocolConfigRecorder { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let recorded = self.recorded.clone(); + + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let mut first = true; + while let Ok(Some(msg)) = stream.message().await { + if first { + *recorded.lock().await = msg.protocol_config; + first = false; + } + let resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + } + }); + + let output = tokio_stream::wrappers::ReceiverStream::new(rx); + Ok(tonic::Response::new(Box::pin(output))) + } +} + +#[tokio::test] +async fn duplex_first_message_includes_protocol_config() { + let recorded = std::sync::Arc::new(tokio::sync::Mutex::new(None)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + let svc = ExternalProcessorServer::new(ProtocolConfigRecorder { + recorded: recorded.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + request_body_mode: BodySendMode::FullDuplexStreamed, + response_body_mode: BodySendMode::FullDuplexStreamed, + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _resp = exchange.receive().await.unwrap(); + + let pc = recorded.lock().await; + let pc = pc.as_ref().expect("first message should include protocol_config"); + assert_eq!( + pc.request_body_mode, 4, + "request_body_mode should be FULL_DUPLEX_STREAMED" + ); + assert_eq!( + pc.response_body_mode, 4, + "response_body_mode should be FULL_DUPLEX_STREAMED" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_second_message_omits_protocol_config() { + let all_configs: std::sync::Arc>>> = + std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new())); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + struct AllConfigRecorder { + configs: std::sync::Arc>>>, + } + + #[async_trait] + impl ExternalProcessor for AllConfigRecorder { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let configs = self.configs.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + configs.lock().await.push(msg.protocol_config); + let resp = build_noop_response(&msg); + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let svc = ExternalProcessorServer::new(AllConfigRecorder { + configs: all_configs.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _r1 = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let _r2 = exchange.receive().await.unwrap(); + + drop(exchange); + let _ = shutdown_tx.send(()); + + let configs = all_configs.lock().await; + assert_eq!(configs.len(), 2, "server should have received 2 messages"); + assert!(configs[0].is_some(), "first message should include protocol_config"); + assert!(configs[1].is_none(), "second message should omit protocol_config"); +} + +#[tokio::test] +async fn duplex_request_headers_round_trip() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-injected".to_owned(), + value: "from-duplex".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestHeaders { .. }), + "should receive a response with header mutation" + ); +} + +#[tokio::test] +async fn duplex_request_body_round_trip() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::HeadersAndBody).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr_resp = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"hello", true)).await.unwrap(); + let body_resp = exchange.receive().await.unwrap(); + assert!( + matches!(body_resp, ExchangeEvent::RequestBody { .. }), + "should receive a body response" + ); +} + +#[tokio::test] +async fn duplex_delayed_routing_no_deadlock() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::DelayedRouting { + header_name: "x-endpoint".to_owned(), + header_value: "10.0.0.1:8080".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"chunk1", false)).await.unwrap(); + exchange.send(make_request_body(b"chunk2", true)).await.unwrap(); + + let header_resp = exchange.receive().await.unwrap(); + assert!( + matches!(header_resp, ExchangeEvent::RequestHeaders { .. }), + "should receive deferred header response after body EOS" + ); + let body_resp = exchange.receive().await.unwrap(); + assert!( + matches!(body_resp, ExchangeEvent::RequestBody { .. }), + "should receive body response after header response" + ); +} + +#[tokio::test] +async fn duplex_multiple_sends_before_any_receive() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::DelayedRouting { + header_name: "x-ep".to_owned(), + header_value: "ep1".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"all-at-once", true)).await.unwrap(); + + let r1 = exchange.receive().await.unwrap(); + let r2 = exchange.receive().await.unwrap(); + assert!( + matches!(r1, ExchangeEvent::RequestHeaders { .. }), + "first response should exist" + ); + assert!( + matches!(r2, ExchangeEvent::RequestBody { .. }), + "second response should exist" + ); +} + +#[tokio::test] +async fn duplex_response_headers_on_same_stream() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::FullLifecycle { + req_header_name: "x-req".to_owned(), + req_header_value: "val".to_owned(), + resp_header_name: "x-resp".to_owned(), + resp_header_value: "val".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let _req_resp = exchange.receive().await.unwrap(); + + exchange.send(make_response_headers()).await.unwrap(); + let resp_resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp_resp, ExchangeEvent::ResponseHeaders { .. }), + "should receive response headers on the same stream" + ); +} + +#[tokio::test] +async fn duplex_streamed_body_chunks() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::StreamedBodyChunks { + chunks: vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()], + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"body", true)).await.unwrap(); + + let _header_resp = exchange.receive().await.unwrap(); + + let mut chunks_received = 0; + for _ in 0..3 { + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestBody { .. }), + "should receive body response chunk" + ); + chunks_received += 1; + } + assert_eq!(chunks_received, 3, "should receive all 3 streamed body chunks"); +} + +#[tokio::test] +async fn duplex_immediate_response_on_headers() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::ImmediateOnHeaders { + status: 403, + body: "blocked".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::Immediate { .. }), + "should receive ImmediateResponse during headers" + ); +} + +#[tokio::test] +async fn duplex_immediate_response_on_body() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::ImmediateOnBody { + status: 413, + body: "too large".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"big", true)).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::Immediate { .. }), + "should receive ImmediateResponse during body" + ); +} + +#[tokio::test] +async fn duplex_unexpected_response_type_rejected() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::UnexpectedResponseType).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "wrong-phase response should be rejected by output validation" + ); +} + +#[tokio::test] +async fn duplex_empty_stream_error() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::CloseEarly).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::EmptyStream)), + "should return EmptyStream when server closes without responding" + ); +} + +#[tokio::test] +async fn duplex_timeout_before_response() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::Hang).await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(50), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::Timeout)), + "should timeout when server hangs" + ); +} + +#[tokio::test] +async fn duplex_timeout_override_accepted() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::OverrideTimeout { + override_ms: 2000, + delay_ms: 200, + name: "x-after".to_owned(), + value: "override".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(100), + max_message_timeout: Some(Duration::from_secs(5)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestHeaders { .. }), + "override should extend deadline past delay" + ); +} + +#[tokio::test] +async fn duplex_timeout_override_clamped() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::OverrideTimeout { + override_ms: 5000, + delay_ms: 300, + name: "x-late".to_owned(), + value: "val".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(100), + max_message_timeout: Some(Duration::from_millis(200)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::Timeout)), + "clamped override should still timeout" + ); +} + +#[tokio::test] +async fn duplex_timeout_override_ignored_without_max() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::OverrideTimeout { + override_ms: 500, + delay_ms: 0, + name: "x-after".to_owned(), + value: "val".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + // Override envelope is consumed and ignored (no max_timeout + // configured). The real response follows. + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "override without max_timeout is consumed and ignored; real response returned" + ); +} + +#[test] +fn exchange_is_send_and_sync() { + fn assert_send() {} + fn assert_sync() {} + assert_send::(); + assert_sync::(); +} + +#[tokio::test] +async fn duplex_transport_error() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + + let channel = Endpoint::from_shared(format!("http://{addr}")).unwrap().connect_lazy(); + + let config = ExchangeConfig { + message_timeout: Duration::from_millis(500), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + let send_result = exchange.send(make_request_headers()).await; + if send_result.is_err() { + return; + } + let recv_result = exchange.receive().await; + assert!(recv_result.is_err(), "connecting to closed port should fail on receive"); +} + +#[tokio::test] +async fn duplex_finish_sending_causes_server_eof() { + let eof_observed = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + struct EofObserver { + eof_observed: std::sync::Arc, + } + + #[async_trait] + impl ExternalProcessor for EofObserver { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let eof_flag = self.eof_observed.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = build_noop_response(&msg); + drop(tx.send(Ok(resp)).await); + } + eof_flag.store(true, std::sync::atomic::Ordering::SeqCst); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let svc = ExternalProcessorServer::new(EofObserver { + eof_observed: eof_observed.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _resp = exchange.receive().await.unwrap(); + exchange.finish_sending(); + + tokio::time::sleep(Duration::from_millis(50)).await; + assert!( + eof_observed.load(std::sync::atomic::Ordering::SeqCst), + "server should observe EOF on request stream after finish_sending" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_receive_after_finish_sending() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::DelayedRouting { + header_name: "x-ep".to_owned(), + header_value: "ep1".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + exchange.finish_sending(); + + let r1 = exchange.receive().await.unwrap(); + assert!( + matches!(r1, ExchangeEvent::RequestHeaders { .. }), + "should still receive after finish_sending" + ); + let r2 = exchange.receive().await.unwrap(); + assert!( + matches!(r2, ExchangeEvent::RequestBody { .. }), + "should receive second response after finish_sending" + ); +} + +#[tokio::test] +async fn duplex_send_after_finish_sending_fails() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-test".to_owned(), + value: "ok".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.finish_sending(); + let result = exchange.send(make_request_headers()).await; + assert!( + matches!(result, Err(ExchangeError::SendFailed)), + "sending after finish_sending should fail deterministically" + ); +} + +#[tokio::test] +async fn duplex_drop_exchange_cleans_up() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::Hang).await; + let channel = connect_channel(addr).await; + let exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + drop(exchange); +} + +#[tokio::test] +async fn duplex_concurrent_exchanges_no_crosstalk() { + struct EchoIdProcessor; + + #[async_trait] + impl ExternalProcessor for EchoIdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + if let Some(processing_request::Request::RequestHeaders(h)) = &msg.request { + let id_header = h + .headers + .as_ref() + .and_then(|m| m.headers.iter().find(|hv| hv.key == "x-exchange-id")) + .map(|hv| hv.value.clone()) + .unwrap_or_default(); + let resp = build_add_header_response(&msg, "x-echo-id", &id_header); + drop(tx.send(Ok(resp)).await); + } + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(EchoIdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let shared_channel = connect_channel(addr).await; + + let mut handles = Vec::new(); + for i in 0_u64..100 { + let channel = shared_channel.clone(); + handles.push(tokio::spawn(async move { + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + let unique_id = format!("exchange-{i}"); + let headers = + processing_request::Request::RequestHeaders(praxis_proto::envoy::service::ext_proc::v3::HttpHeaders { + headers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { + headers: vec![ + HeaderValue { + key: ":method".to_owned(), + value: "GET".to_owned(), + raw_value: Vec::new(), + }, + HeaderValue { + key: "x-exchange-id".to_owned(), + value: unique_id.clone(), + raw_value: Vec::new(), + }, + ], + }), + end_of_stream: false, + }); + exchange.send(headers).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + if let ExchangeEvent::RequestHeaders { response: hr, .. } = &resp + && let Some(common) = &hr.response + && let Some(mutation) = &common.header_mutation + { + let echoed = mutation + .set_headers + .iter() + .find(|hvo| hvo.header.as_ref().is_some_and(|h| h.key == "x-echo-id")) + .and_then(|hvo| hvo.header.as_ref()) + .map(|h| h.value.as_str()); + assert_eq!( + echoed, + Some(unique_id.as_str()), + "exchange {i} should echo back its own unique ID" + ); + return; + } + panic!("exchange {i} did not receive expected echo response"); + })); + } + for handle in handles { + handle.await.unwrap(); + } + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_exchange_is_send_and_sync() { + fn assert_send_sync() {} + assert_send_sync::(); +} + +#[tokio::test] +async fn duplex_existing_fd00_tests_unaffected() { + let (addr, _guard) = start_mock_processor(MockBehavior::AddHeader { + name: "x-existing".to_owned(), + value: "works".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let req = make_request(Method::GET, "/test"); + let mut ctx = make_ctx(&req); + let action = callout::process_request_headers(channel, &addr.to_string(), Duration::from_secs(5), None, &mut ctx) + .await + .unwrap(); + assert!( + matches!(action, FilterAction::Continue), + "existing callout should still work alongside duplex module" + ); +} + +#[tokio::test] +async fn duplex_terminal_state_after_timeout() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::Hang).await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(50), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _timeout = exchange.receive().await; + assert!(exchange.is_terminal(), "exchange should be closed after timeout"); + + let send_result = exchange.send(make_request_headers()).await; + assert!( + matches!(send_result, Err(ExchangeError::Closed)), + "send after timeout should return Closed" + ); + let recv_result = exchange.receive().await; + assert!( + matches!(recv_result, Err(ExchangeError::Closed)), + "receive after timeout should return Closed" + ); +} + +#[tokio::test] +async fn duplex_response_body_round_trip() { + struct ResponseBodyProcessor; + + #[async_trait] + impl ExternalProcessor for ResponseBodyProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::ResponseHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::ResponseBody(_)) => ProcessingResponse { + response: Some(processing_response::Response::ResponseBody(BodyResponse { + response: None, + })), + ..Default::default() + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ResponseBodyProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let _req_hdr = exchange.receive().await.unwrap(); + + exchange.send(make_response_headers()).await.unwrap(); + let _resp_hdr = exchange.receive().await.unwrap(); + + let resp_body = processing_request::Request::ResponseBody(praxis_proto::envoy::service::ext_proc::v3::HttpBody { + body: b"response body data".to_vec(), + end_of_stream: true, + }); + exchange.send(resp_body).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::ResponseBody { .. }), + "should receive ResponseBody response" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_server_observes_client_cancellation() { + let cancelled = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + struct CancellationObserver { + cancelled: std::sync::Arc, + } + + #[async_trait] + impl ExternalProcessor for CancellationObserver { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let flag = self.cancelled.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + if let Ok(Some(_msg)) = stream.message().await { + let resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + } + while let Ok(Some(_msg)) = stream.message().await {} + flag.store(true, std::sync::atomic::Ordering::SeqCst); + drop(tx); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let svc = ExternalProcessorServer::new(CancellationObserver { + cancelled: cancelled.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _resp = exchange.receive().await.unwrap(); + drop(exchange); + + tokio::time::sleep(Duration::from_millis(200)).await; + assert!( + cancelled.load(std::sync::atomic::Ordering::SeqCst), + "server should observe client cancellation when exchange is dropped" + ); + + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_repeated_clean_close() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-close".to_owned(), + value: "test".to_owned(), + }) + .await; + + let shared_channel = connect_channel(addr).await; + + for i in 0..20 { + let mut exchange = ExtProcExchange::open(shared_channel.clone(), &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestHeaders { .. }), + "exchange {i} should receive a response" + ); + exchange.finish_sending(); + } +} + +// ----------------------------------------------------------------------------- +// Directional State and Ordering Tests +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn duplex_request_body_before_headers_rejected() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-t".to_owned(), + value: "v".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + let result = exchange.send(make_request_body(b"data", false)).await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "body before headers should be rejected" + ); +} + +#[tokio::test] +async fn duplex_duplicate_request_headers_rejected() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-t".to_owned(), + value: "v".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.send(make_request_headers()).await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "duplicate request headers should be rejected" + ); +} + +#[tokio::test] +async fn duplex_body_after_eos_rejected() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::HeadersAndBody).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let result = exchange.send(make_request_body(b"more", false)).await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "body after EOS should be rejected" + ); +} + +#[tokio::test] +async fn duplex_legal_response_while_request_open() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::FullLifecycle { + req_header_name: "x-r".to_owned(), + req_header_value: "v".to_owned(), + resp_header_name: "x-s".to_owned(), + resp_header_value: "v".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _req_resp = exchange.receive().await.unwrap(); + exchange.send(make_response_headers()).await.unwrap(); + let resp_resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp_resp, ExchangeEvent::ResponseHeaders { .. }), + "response headers should be legal while request direction has only sent headers" + ); +} + +#[tokio::test] +async fn duplex_request_trailers_send_and_classify() { + struct TrailerProcessor; + + #[async_trait] + impl ExternalProcessor for TrailerProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::RequestTrailers(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestTrailers(TrailersResponse { + header_mutation: None, + })), + ..Default::default() + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(TrailerProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + + let trailers = + processing_request::Request::RequestTrailers(praxis_proto::envoy::service::ext_proc::v3::HttpTrailers { + trailers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { headers: vec![] }), + }); + exchange.send(trailers).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestTrailers { .. }), + "should classify request trailers response" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_dynamic_metadata_preserved() { + struct MetadataProcessor; + + #[async_trait] + impl ExternalProcessor for MetadataProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + if let Ok(Some(_msg)) = stream.message().await { + let mut fields = HashMap::new(); + fields.insert( + "test_key".to_owned(), + prost_wkt_types::Value { + kind: Some(prost_wkt_types::value::Kind::StringValue("test_value".to_owned())), + }, + ); + let resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + dynamic_metadata: Some(prost_wkt_types::Struct { fields }), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(MetadataProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match event { + ExchangeEvent::RequestHeaders { metadata, .. } => { + let md = metadata.expect("metadata should be present"); + assert!( + md.fields.contains_key("test_key"), + "dynamic_metadata should be preserved on typed event" + ); + }, + other => panic!("expected RequestHeaders, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_full_duplex_no_per_message_timeout() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::DelayedRouting { + header_name: "x-ep".to_owned(), + header_value: "ep1".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(50), + request_body_mode: BodySendMode::FullDuplexStreamed, + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "full-duplex receive without timeout should succeed even with low message_timeout" + ); +} + +#[tokio::test] +async fn duplex_immediate_response_sets_terminal() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::ImmediateOnHeaders { + status: 403, + body: "blocked".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::Immediate { .. }), + "should be immediate event" + ); + assert!( + exchange.is_terminal(), + "exchange should be terminal after ImmediateResponse" + ); + let send_result = exchange.send(make_request_body(b"data", true)).await; + assert!( + matches!(send_result, Err(ExchangeError::Closed)), + "send after immediate should return Closed" + ); +} + +#[tokio::test] +async fn duplex_override_envelope_ignores_response_data() { + struct OverrideWithResponseProcessor; + + #[async_trait] + impl ExternalProcessor for OverrideWithResponseProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + if let Ok(Some(msg)) = stream.message().await { + let override_with_response = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + override_message_timeout: Some(prost_types::Duration { seconds: 2, nanos: 0 }), + ..Default::default() + }; + drop(tx.send(Ok(override_with_response)).await); + tokio::time::sleep(Duration::from_millis(100)).await; + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(OverrideWithResponseProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(5)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match &event { + ExchangeEvent::RequestHeaders { response, .. } => { + assert!( + response.response.is_some(), + "should receive the REAL response, not the override envelope's response" + ); + }, + other => panic!("expected RequestHeaders from real response, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +// ----------------------------------------------------------------------------- +// Duplex Exchange Evidence Tests +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn duplex_body_mode_none_rejects_body_send() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-t".to_owned(), + value: "v".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + let result = exchange.send(make_request_body(b"rejected", false)).await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "body send with BodySendMode::None should be rejected with OrderingViolation" + ); +} + +#[tokio::test] +async fn duplex_non_full_duplex_body_creates_active_state() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::HeadersAndBody).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let hdr_resp = exchange.receive().await.unwrap(); + assert!( + matches!(hdr_resp, ExchangeEvent::RequestHeaders { .. }), + "should receive header response before sending body" + ); + + exchange.send(make_request_body(b"chunk1", true)).await.unwrap(); + let body_resp = exchange.receive().await.unwrap(); + assert!( + matches!(body_resp, ExchangeEvent::RequestBody { .. }), + "non-full-duplex body chunk must receive body response before sending another" + ); +} + +#[tokio::test] +async fn duplex_second_non_fd_send_while_active_rejected() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::HeadersAndBody).await; + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.send(make_request_body(b"chunk", false)).await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "sending body before headers response should fail because active state is already outstanding" + ); +} + +#[tokio::test] +async fn duplex_response_trailers_send_and_classify() { + struct ResponseTrailerProcessor; + + #[async_trait] + impl ExternalProcessor for ResponseTrailerProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::ResponseHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::ResponseTrailers(_)) => ProcessingResponse { + response: Some(processing_response::Response::ResponseTrailers(TrailersResponse { + header_mutation: None, + })), + ..Default::default() + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ResponseTrailerProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let _req_hdr = exchange.receive().await.unwrap(); + + exchange.send(make_response_headers()).await.unwrap(); + let _resp_hdr = exchange.receive().await.unwrap(); + + let trailers = + processing_request::Request::ResponseTrailers(praxis_proto::envoy::service::ext_proc::v3::HttpTrailers { + trailers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { headers: vec![] }), + }); + exchange.send(trailers).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::ResponseTrailers { .. }), + "should classify response trailers response" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_metadata_on_body_event() { + struct BodyMetadataProcessor; + + #[async_trait] + impl ExternalProcessor for BodyMetadataProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::RequestBody(_)) => { + let mut fields = HashMap::new(); + fields.insert( + "body_key".to_owned(), + prost_wkt_types::Value { + kind: Some(prost_wkt_types::value::Kind::StringValue("body_value".to_owned())), + }, + ); + ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + dynamic_metadata: Some(prost_wkt_types::Struct { fields }), + ..Default::default() + } + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(BodyMetadataProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match event { + ExchangeEvent::RequestBody { metadata, .. } => { + let md = metadata.expect("metadata should be present on body event"); + assert!( + md.fields.contains_key("body_key"), + "dynamic_metadata should be preserved on ExchangeEvent::RequestBody" + ); + }, + other => panic!("expected RequestBody, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_metadata_on_immediate_event() { + struct ImmediateMetadataProcessor; + + #[async_trait] + impl ExternalProcessor for ImmediateMetadataProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + if let Ok(Some(_msg)) = stream.message().await { + let mut fields = HashMap::new(); + fields.insert( + "imm_key".to_owned(), + prost_wkt_types::Value { + kind: Some(prost_wkt_types::value::Kind::StringValue("imm_value".to_owned())), + }, + ); + let resp = ProcessingResponse { + response: Some(processing_response::Response::ImmediateResponse(ImmediateResponse { + status: Some(HttpStatus { code: 429 }), + headers: None, + body: "rate limited".to_owned(), + grpc_status: None, + details: String::new(), + })), + dynamic_metadata: Some(prost_wkt_types::Struct { fields }), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ImmediateMetadataProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match event { + ExchangeEvent::Immediate { metadata, .. } => { + let md = metadata.expect("metadata should be present on immediate event"); + assert!( + md.fields.contains_key("imm_key"), + "dynamic_metadata should be preserved on ExchangeEvent::Immediate" + ); + }, + other => panic!("expected Immediate, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_override_ignored_in_full_duplex() { + struct FullDuplexOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for FullDuplexOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let override_envelope = build_override_response(5000); + drop(tx.send(Ok(override_envelope)).await); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + let _body = stream.message().await.unwrap().unwrap(); + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: b"data".to_vec(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(FullDuplexOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..full_duplex_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let hdr_event = exchange.receive().await.unwrap(); + assert!( + matches!(hdr_event, ExchangeEvent::RequestHeaders { .. }), + "override envelope ignored; real header response returned" + ); + + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let body_event = exchange.receive().await.unwrap(); + assert!( + matches!(body_event, ExchangeEvent::RequestBody { .. }), + "should receive body response in full-duplex mode" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_repeated_override_ignored() { + struct DoubleOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for DoubleOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let override1 = build_override_response(2000); + drop(tx.send(Ok(override1)).await); + let override2 = build_override_response(3000); + drop(tx.send(Ok(override2)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(DoubleOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match &event { + ExchangeEvent::RequestHeaders { response, .. } => { + assert!( + response.response.is_some(), + "should receive the real response with header mutation, not an override envelope" + ); + }, + other => panic!("expected RequestHeaders from real response, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_zero_duration_override_ignored() { + struct ZeroOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for ZeroOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let zero_override = ProcessingResponse { + override_message_timeout: Some(prost_types::Duration { seconds: 0, nanos: 0 }), + ..Default::default() + }; + drop(tx.send(Ok(zero_override)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ZeroOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match &event { + ExchangeEvent::RequestHeaders { response, .. } => { + assert!( + response.response.is_some(), + "should receive the real response, not the zero-override envelope" + ); + }, + other => panic!("expected RequestHeaders from real response, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_backpressure_deterministic() { + struct BarrierProcessor { + barrier: std::sync::Arc, + } + + #[async_trait] + impl ExternalProcessor for BarrierProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let barrier = self.barrier.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(16); + tokio::spawn(async move { + let _msg = stream.message().await.unwrap().unwrap(); + barrier.wait().await; + while let Ok(Some(msg)) = stream.message().await { + let resp = build_noop_response(&msg); + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let barrier = std::sync::Arc::new(tokio::sync::Barrier::new(2)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(BarrierProcessor { + barrier: barrier.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + + let chunk_size = 16_384; // 16 KiB + let big_chunk = vec![0xAB_u8; chunk_size]; + let max_attempts = 256; + let mut sent_count = 0_usize; + for _ in 0..max_attempts { + let send_fut = exchange.send(make_request_body(&big_chunk, false)); + let result = tokio::time::timeout(Duration::from_millis(200), send_fut).await; + if result.is_err() { + break; + } + result.unwrap().unwrap(); + sent_count += 1; + } + assert!( + sent_count < max_attempts, + "sends should eventually block due to backpressure; sent all {max_attempts} without blocking" + ); + + barrier.wait().await; + + let resume_send = exchange.send(make_request_body(b"after-release", true)); + let result = tokio::time::timeout(Duration::from_secs(2), resume_send).await; + assert!(result.is_ok(), "sends should resume after barrier is released"); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_deadline_starts_at_send_commit() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::Hang).await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(50), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + + let before_send = tokio::time::Instant::now(); + exchange.send(make_request_headers()).await.unwrap(); + + tokio::time::sleep(Duration::from_millis(30)).await; + + let result = exchange.receive().await; + let elapsed = before_send.elapsed(); + assert!( + matches!(result, Err(ExchangeError::Timeout)), + "should timeout when server hangs" + ); + assert!( + elapsed < Duration::from_millis(100), + "deadline should be ~50ms from send, not from receive; elapsed: {elapsed:?}" + ); + assert!( + elapsed >= Duration::from_millis(40), + "deadline should not expire before the configured timeout; elapsed: {elapsed:?}" + ); +} + +#[tokio::test] +async fn duplex_unsolicited_response_rejected() { + struct UnsolicitedResponseProcessor; + + #[async_trait] + impl ExternalProcessor for UnsolicitedResponseProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _msg = stream.message().await.unwrap().unwrap(); + let resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(UnsolicitedResponseProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "unsolicited response for direction with no outbound headers should be rejected" + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("expected") || err.contains("unsolicited"), + "error should indicate wrong response type: {err}" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_full_duplex_headers_no_timeout() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::OverrideTimeout { + override_ms: 0, + delay_ms: 200, + name: "x-delayed".to_owned(), + value: "ok".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::from_millis(50), + request_body_mode: BodySendMode::FullDuplexStreamed, + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "full-duplex headers should not timeout even when response is delayed past message_timeout" + ); +} + +#[tokio::test] +async fn duplex_full_duplex_trailers_while_deferred() { + struct DeferredTrailerProcessor; + + #[async_trait] + impl ExternalProcessor for DeferredTrailerProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(16); + tokio::spawn(async move { + let mut messages = Vec::new(); + while let Ok(Some(msg)) = stream.message().await { + let is_trailers = matches!(msg.request, Some(processing_request::Request::RequestTrailers(_))); + messages.push(msg); + if is_trailers { + break; + } + } + for msg in &messages { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::RequestBody(_)) => { + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse( + StreamedBodyResponse { + body: Vec::new(), + end_of_stream: false, + }, + )), + }), + ..Default::default() + }), + })), + ..Default::default() + } + }, + Some(processing_request::Request::RequestTrailers(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestTrailers(TrailersResponse { + header_mutation: None, + })), + ..Default::default() + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(DeferredTrailerProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"chunk1", false)).await.unwrap(); + exchange.send(make_request_body(b"chunk2", false)).await.unwrap(); + + let trailers = + processing_request::Request::RequestTrailers(praxis_proto::envoy::service::ext_proc::v3::HttpTrailers { + trailers: Some(praxis_proto::envoy::service::ext_proc::v3::HeaderMap { headers: vec![] }), + }); + exchange.send(trailers).await.unwrap(); + + let r1 = exchange.receive().await.unwrap(); + assert!( + matches!(r1, ExchangeEvent::RequestHeaders { .. }), + "should receive deferred header response" + ); + let r2 = exchange.receive().await.unwrap(); + assert!( + matches!(r2, ExchangeEvent::RequestBody { .. }), + "should receive body response" + ); + let r3 = exchange.receive().await.unwrap(); + assert!( + matches!(r3, ExchangeEvent::RequestBody { .. }), + "should receive second body response" + ); + let r4 = exchange.receive().await.unwrap(); + assert!( + matches!(r4, ExchangeEvent::RequestTrailers { .. }), + "should receive trailer response" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_streamed_body_response_in_non_fd_rejected() { + struct StreamedInNonFdProcessor; + + #[async_trait] + impl ExternalProcessor for StreamedInNonFdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + let _body = stream.message().await.unwrap().unwrap(); + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: b"streamed".to_vec(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(StreamedInNonFdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "StreamedBodyResponse mutation in non-full-duplex mode should be rejected" + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("StreamedBodyResponse"), + "error should mention StreamedBodyResponse: {err}" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_non_streamed_body_response_in_fd_rejected() { + struct NonStreamedInFdProcessor; + + #[async_trait] + impl ExternalProcessor for NonStreamedInFdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + let _body = stream.message().await.unwrap().unwrap(); + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(NonStreamedInFdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "non-StreamedBodyResponse mutation in full-duplex mode should be rejected" + ); + let err = result.unwrap_err().to_string(); + assert!(err.contains("full-duplex"), "error should mention full-duplex: {err}"); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +// ----------------------------------------------------------------------------- +// Duplex Exchange Regression Tests +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn duplex_request_body_response_without_body_send_rejected() { + struct HeadersOnlyFdProcessor; + + #[async_trait] + impl ExternalProcessor for HeadersOnlyFdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: b"unsolicited".to_vec(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(HeadersOnlyFdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "body response without any body send should be rejected in full-duplex" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_request_trailer_response_without_trailer_send_rejected() { + struct HeadersBodyNoTrailerFdProcessor; + + #[async_trait] + impl ExternalProcessor for HeadersBodyNoTrailerFdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + let _body = stream.message().await.unwrap().unwrap(); + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: b"data".to_vec(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + + let trailer_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestTrailers(TrailersResponse { + header_mutation: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(trailer_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(HeadersBodyNoTrailerFdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + let _body = exchange.receive().await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "trailer response without trailer send should be rejected in full-duplex" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_response_body_response_without_body_send_rejected() { + struct ResponseHeadersOnlyFdProcessor; + + #[async_trait] + impl ExternalProcessor for ResponseHeadersOnlyFdProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _req_hdrs = stream.message().await.unwrap().unwrap(); + let req_hdr_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(req_hdr_resp)).await); + + let _resp_hdrs = stream.message().await.unwrap().unwrap(); + let resp_hdr_resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp_hdr_resp)).await); + + use praxis_proto::envoy::service::ext_proc::v3::{ + BodyMutation, CommonResponse, StreamedBodyResponse, body_mutation, + }; + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseBody(BodyResponse { + response: Some(CommonResponse { + body_mutation: Some(BodyMutation { + mutation: Some(body_mutation::Mutation::StreamedResponse(StreamedBodyResponse { + body: b"unsolicited".to_vec(), + end_of_stream: true, + })), + }), + ..Default::default() + }), + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ResponseHeadersOnlyFdProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _req_hdr = exchange.receive().await.unwrap(); + exchange.send(make_response_headers()).await.unwrap(); + let _resp_hdr = exchange.receive().await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "response body response without body send should be rejected in full-duplex" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_response_trailer_response_without_trailer_send_rejected() { + struct ResponseHeadersOnlyTrailerProcessor; + + #[async_trait] + impl ExternalProcessor for ResponseHeadersOnlyTrailerProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _req_hdrs = stream.message().await.unwrap().unwrap(); + let req_hdr_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(req_hdr_resp)).await); + + let _resp_hdrs = stream.message().await.unwrap().unwrap(); + let resp_hdr_resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp_hdr_resp)).await); + + let trailer_resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseTrailers(TrailersResponse { + header_mutation: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(trailer_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ResponseHeadersOnlyTrailerProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &full_duplex_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _req_hdr = exchange.receive().await.unwrap(); + exchange.send(make_response_headers()).await.unwrap(); + let _resp_hdr = exchange.receive().await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "response trailer response without trailer send should be rejected in full-duplex" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_duplicate_non_fd_body_response_rejected() { + struct DuplicateBodyResponseProcessor; + + #[async_trait] + impl ExternalProcessor for DuplicateBodyResponseProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let header_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(header_resp)).await); + + let _body = stream.message().await.unwrap().unwrap(); + let body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(body_resp)).await); + + let dup_body_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(dup_body_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(DuplicateBodyResponseProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &streamed_body_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + let _body = exchange.receive().await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "duplicate body response in non-full-duplex mode should be rejected (no active state)" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_cross_direction_non_fd_response_without_active_match_rejected() { + struct CrossDirectionProcessor; + + #[async_trait] + impl ExternalProcessor for CrossDirectionProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let resp = ProcessingResponse { + response: Some(processing_response::Response::ResponseHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(CrossDirectionProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "cross-direction ResponseHeaders without response headers committed should be rejected" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_unsolicited_immediate_before_first_send_rejected() { + struct ImmediateBeforeSendProcessor; + + #[async_trait] + impl ExternalProcessor for ImmediateBeforeSendProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let resp = ProcessingResponse { + response: Some(processing_response::Response::ImmediateResponse(ImmediateResponse { + status: Some(HttpStatus { code: 500 }), + headers: None, + body: "unsolicited".to_owned(), + grpc_status: None, + details: String::new(), + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(ImmediateBeforeSendProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "immediate response before first send should be rejected" + ); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("before first send"), + "error should mention 'before first send': {err}" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_rejected_response_does_not_advance_output_phase() { + struct WrongThenCorrectProcessor; + + #[async_trait] + impl ExternalProcessor for WrongThenCorrectProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(8); + tokio::spawn(async move { + let _headers = stream.message().await.unwrap().unwrap(); + let wrong_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestBody(BodyResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(wrong_resp)).await); + let correct_resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(correct_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(WrongThenCorrectProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + request_body_mode: BodySendMode::Streamed, + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + + let (req_before, resp_before) = exchange.output_phases(); + + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "RequestBody response before RequestHeaders output should be rejected" + ); + + let (req_after, resp_after) = exchange.output_phases(); + assert_eq!( + req_before, req_after, + "request output phase should be unchanged after rejection" + ); + assert_eq!( + resp_before, resp_after, + "response output phase should be unchanged after rejection" + ); + + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_negative_override_ignored() { + struct NegativeOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for NegativeOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let bad_override = ProcessingResponse { + override_message_timeout: Some(prost_types::Duration { seconds: -1, nanos: 0 }), + ..Default::default() + }; + drop(tx.send(Ok(bad_override)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(NegativeOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "negative seconds override should be consumed and ignored; real response returned" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_negative_nanos_override_ignored() { + struct NegativeNanosOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for NegativeNanosOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let bad_override = ProcessingResponse { + override_message_timeout: Some(prost_types::Duration { seconds: 1, nanos: -1 }), + ..Default::default() + }; + drop(tx.send(Ok(bad_override)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(NegativeNanosOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "negative nanos override should be consumed and ignored; real response returned" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_out_of_range_nanos_override_ignored() { + struct OutOfRangeNanosProcessor; + + #[async_trait] + impl ExternalProcessor for OutOfRangeNanosProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let bad_override = ProcessingResponse { + override_message_timeout: Some(prost_types::Duration { + seconds: 1, + nanos: 2_000_000_000, + }), + ..Default::default() + }; + drop(tx.send(Ok(bad_override)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(OutOfRangeNanosProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "out-of-range nanos override should be consumed and ignored; real response returned" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_sub_millisecond_override_ignored() { + struct SubMsOverrideProcessor; + + #[async_trait] + impl ExternalProcessor for SubMsOverrideProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + let msg = stream.message().await.unwrap().unwrap(); + let bad_override = ProcessingResponse { + override_message_timeout: Some(prost_types::Duration { + seconds: 0, + nanos: 500_000, // 0.5ms, below MIN_OVERRIDE + }), + ..Default::default() + }; + drop(tx.send(Ok(bad_override)).await); + let real_resp = build_add_header_response(&msg, "x-real", "response"); + drop(tx.send(Ok(real_resp)).await); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(SubMsOverrideProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + max_message_timeout: Some(Duration::from_secs(10)), + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "sub-millisecond override (0.5ms) should be consumed and ignored; real response returned" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_deadline_overflow_returns_error_not_panic() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::EchoHeaders { + name: "x-overflow".to_owned(), + value: "test".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = ExchangeConfig { + message_timeout: Duration::MAX, + ..default_exchange_config() + }; + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + let result = exchange.send(make_request_headers()).await; + assert!( + matches!(result, Err(ExchangeError::DeadlineOverflow)), + "Duration::MAX should fail at send with deadline overflow, not panic" + ); +} + +#[tokio::test] +async fn duplex_trailer_metadata_preserved() { + struct TrailerMetadataProcessor; + + #[async_trait] + impl ExternalProcessor for TrailerMetadataProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = match &msg.request { + Some(processing_request::Request::RequestHeaders(_)) => ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }, + Some(processing_request::Request::RequestTrailers(_)) => { + let mut fields = HashMap::new(); + fields.insert( + "trailer_key".to_owned(), + prost_wkt_types::Value { + kind: Some(prost_wkt_types::value::Kind::StringValue("trailer_value".to_owned())), + }, + ); + ProcessingResponse { + response: Some(processing_response::Response::RequestTrailers(TrailersResponse { + header_mutation: None, + })), + dynamic_metadata: Some(prost_wkt_types::Struct { fields }), + ..Default::default() + } + }, + _ => ProcessingResponse::default(), + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(TrailerMetadataProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_trailers()).await.unwrap(); + let event = exchange.receive().await.unwrap(); + match event { + ExchangeEvent::RequestTrailers { metadata, .. } => { + let md = metadata.expect("metadata should be present on trailer event"); + assert!( + md.fields.contains_key("trailer_key"), + "dynamic_metadata should be preserved on ExchangeEvent::RequestTrailers" + ); + }, + other => panic!("expected RequestTrailers, got {other:?}"), + } + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_cancelled_blocked_send_leaves_state_unchanged() { + use crate::duplex::commit_message; + + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let timeout = Duration::from_millis(200); + + let fill_msg = ProcessingRequest { + request: Some(make_request_headers()), + ..Default::default() + }; + tx.send(fill_msg).await.unwrap(); + + { + let cancelled_msg = ProcessingRequest { + request: Some(make_request_body(b"CANCELLED_ID", false)), + ..Default::default() + }; + let blocked = commit_message(&tx, cancelled_msg, Some(timeout)); + let poll_result = tokio::time::timeout(Duration::from_millis(50), blocked).await; + assert!(poll_result.is_err(), "send should be pending while channel is full"); + } + + let first = rx.recv().await.unwrap(); + assert!( + matches!(first.request, Some(processing_request::Request::RequestHeaders(_))), + "first received should be the fill message, not the cancelled body" + ); + + assert!( + rx.try_recv().is_err(), + "channel should be empty after removing the fill message; cancelled message must not be present" + ); + + let followup_msg = ProcessingRequest { + request: Some(make_request_body(b"FOLLOWUP_ID", true)), + ..Default::default() + }; + let result = commit_message(&tx, followup_msg, None).await; + assert!(result.is_ok(), "follow-up should succeed after cancelled send"); + + let received = rx.recv().await.unwrap(); + if let Some(processing_request::Request::RequestBody(body)) = &received.request { + assert_eq!( + body.body, b"FOLLOWUP_ID", + "should receive follow-up, not cancelled message" + ); + } else { + panic!("expected RequestBody follow-up"); + } +} + +#[tokio::test] +async fn duplex_repeated_close_with_eof_count() { + let eof_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)); + + struct EofCountingProcessor { + eof_count: std::sync::Arc, + } + + #[async_trait] + impl ExternalProcessor for EofCountingProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let eof_count = self.eof_count.clone(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(msg)) = stream.message().await { + let resp = build_noop_response(&msg); + drop(tx.send(Ok(resp)).await); + } + eof_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(EofCountingProcessor { + eof_count: eof_count.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let shared_channel = connect_channel(addr).await; + + for i in 0..100 { + let mut exchange = ExtProcExchange::open(shared_channel.clone(), &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let resp = exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestHeaders { .. }), + "exchange {i} should receive a response" + ); + exchange.finish_sending(); + drop(exchange); + } + + tokio::time::sleep(Duration::from_millis(200)).await; + + let observed = eof_count.load(std::sync::atomic::Ordering::SeqCst); + assert_eq!(observed, 100, "server should observe exactly 100 EOFs, got {observed}"); + + let mut final_exchange = ExtProcExchange::open(shared_channel.clone(), &default_exchange_config()).unwrap(); + final_exchange.send(make_request_headers()).await.unwrap(); + let resp = final_exchange.receive().await.unwrap(); + assert!( + matches!(resp, ExchangeEvent::RequestHeaders { .. }), + "final exchange on same channel should succeed" + ); + final_exchange.finish_sending(); + + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_cross_direction_started_non_fd_duplicate_body_rejected() { + struct WrongTypeProcessor; + + #[async_trait] + impl ExternalProcessor for WrongTypeProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(16); + tokio::spawn(async move { + let mut msg_count = 0_u32; + while let Ok(Some(msg)) = stream.message().await { + msg_count += 1; + let resp = if msg_count == 5 { + ProcessingResponse { + response: Some(processing_response::Response::ResponseBody(BodyResponse { + response: None, + })), + ..Default::default() + } + } else { + build_noop_response(&msg) + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + let svc = ExternalProcessorServer::new(WrongTypeProcessor); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let config = streamed_body_exchange_config(); + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _rh = exchange.receive().await.unwrap(); + + exchange.send(make_request_body(b"chunk1", false)).await.unwrap(); + let _rb = exchange.receive().await.unwrap(); + + exchange.send(make_response_headers()).await.unwrap(); + let _resh = exchange.receive().await.unwrap(); + + exchange + .send(processing_request::Request::ResponseBody( + praxis_proto::envoy::service::ext_proc::v3::HttpBody { + body: b"resp_body".to_vec(), + end_of_stream: false, + }, + )) + .await + .unwrap(); + let _resb = exchange.receive().await.unwrap(); + + exchange.send(make_request_body(b"chunk2", false)).await.unwrap(); + let result = exchange.receive().await; + assert!( + matches!(result, Err(ExchangeError::OrderingViolation(_))), + "ResponseBody when active expects RequestBody must be rejected" + ); + drop(exchange); + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn duplex_blocked_send_deadline_starts_after_commit() { + use std::pin::pin; + + use crate::duplex::commit_message; + + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let timeout = Duration::from_millis(200); + + let fill_msg = ProcessingRequest { + request: Some(make_request_headers()), + ..Default::default() + }; + tx.send(fill_msg).await.unwrap(); + + let target_msg = ProcessingRequest { + request: Some(make_request_body(b"target", false)), + ..Default::default() + }; + let mut blocked_future = pin!(commit_message(&tx, target_msg, Some(timeout))); + + let poll_result = tokio::time::timeout(Duration::from_millis(250), &mut blocked_future).await; + assert!(poll_result.is_err(), "send should remain pending while channel is full"); + + let _first = rx.recv().await.unwrap(); + + let result = blocked_future.await; + let deadline = result.unwrap().unwrap(); + + let remaining = deadline.duration_since(tokio::time::Instant::now()); + assert!( + remaining > Duration::from_millis(150), + "deadline should have ~200ms remaining since it started at commit, not at reserve: {remaining:?}" + ); +} + +// ------------------------------------------------------------------------- +// Single-Owner Pending-Process Driver Tests +// ------------------------------------------------------------------------- + +#[tokio::test(flavor = "current_thread")] +async fn driver_delayed_response_headers_current_thread() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::DelayedRouting { + header_name: "x-ep".to_owned(), + header_value: "ep1".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = full_duplex_exchange_config(); + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + exchange.send(make_request_body(b"chunk1", false)).await.unwrap(); + exchange.send(make_request_body(b"chunk2", false)).await.unwrap(); + exchange.send(make_request_body(b"", true)).await.unwrap(); + + let event = tokio::time::timeout(Duration::from_secs(5), exchange.receive()) + .await + .expect("should not timeout") + .expect("should receive headers response"); + assert!( + matches!(event, ExchangeEvent::RequestHeaders { .. }), + "first event should be request headers response" + ); +} + +#[tokio::test] +async fn driver_one_process_invocation() { + let invocation_count = std::sync::Arc::new(std::sync::atomic::AtomicU32::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); + + struct CountingProcessor { + count: std::sync::Arc, + } + + #[async_trait] + impl ExternalProcessor for CountingProcessor { + type ProcessStream = Pin> + Send>>; + + async fn process( + &self, + request: tonic::Request>, + ) -> Result, tonic::Status> { + self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let mut stream = request.into_inner(); + let (tx, rx) = tokio::sync::mpsc::channel(4); + tokio::spawn(async move { + while let Ok(Some(_msg)) = stream.message().await { + let resp = ProcessingResponse { + response: Some(processing_response::Response::RequestHeaders(HeadersResponse { + response: None, + })), + ..Default::default() + }; + drop(tx.send(Ok(resp)).await); + } + }); + Ok(tonic::Response::new(Box::pin( + tokio_stream::wrappers::ReceiverStream::new(rx), + ))) + } + } + + let svc = ExternalProcessorServer::new(CountingProcessor { + count: invocation_count.clone(), + }); + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown(tokio_stream::wrappers::TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }) + .await + .unwrap(); + }); + wait_for_server(addr).await; + + let channel = connect_channel(addr).await; + let mut exchange = ExtProcExchange::open(channel, &default_exchange_config()).unwrap(); + exchange.send(make_request_headers()).await.unwrap(); + let _resp = exchange.receive().await.unwrap(); + drop(exchange.send(make_request_headers()).await); + drop(exchange); + + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!( + invocation_count.load(std::sync::atomic::Ordering::SeqCst), + 1, + "exactly one Process invocation per exchange" + ); + + let _ = shutdown_tx.send(()); +} + +#[tokio::test] +async fn driver_outbound_half_close_preserves_drain() { + let (addr, _guard) = start_duplex_processor(DuplexBehavior::ImmediateOnBody { + status: 403, + body: "blocked".to_owned(), + }) + .await; + let channel = connect_channel(addr).await; + let config = full_duplex_exchange_config(); + let mut exchange = ExtProcExchange::open(channel, &config).unwrap(); + + exchange.send(make_request_headers()).await.unwrap(); + let _hdr = exchange.receive().await.unwrap(); + exchange.send(make_request_body(b"data", true)).await.unwrap(); + + let event = tokio::time::timeout(Duration::from_secs(5), exchange.receive()) + .await + .expect("should not timeout") + .expect("should receive immediate response"); + assert!( + matches!(&event, ExchangeEvent::Immediate { response, .. } if response.status.as_ref().is_some_and(|s| s.code == 403)), + "should receive ImmediateResponse with exact 403 status" + ); +}