From ec0bb23118cb26c8afec49865133352935cec014 Mon Sep 17 00:00:00 2001 From: Oleksii Date: Wed, 10 Jun 2026 01:21:50 -0300 Subject: [PATCH] =?UTF-8?q?feat(llm):=20streaming=20llm=5Fcall=20=E2=80=94?= =?UTF-8?q?=20provider=20SSE=20consumption=20with=20durable=20accumulated?= =?UTF-8?q?=20output?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `stream: true` on llm_call consumes the provider's streaming wire protocol (OpenAI /chat/completions SSE with stream_options.include_usage; Anthropic /messages event stream), publishes incremental text deltas to clients watching GET /instances/{id}/stream, and produces a durable step output identical to the non-streaming path (message, finish_reason, usage, tool calls) — downstream blocks are unaffected. - New `orch8_engine::stream_bus`: per-instance tokio broadcast registry (lazily created on subscribe, capacity-bounded, dropped when the last subscriber disconnects). Handlers publish `llm_delta` events; the API's SSE endpoint subscribes in the same process and forwards them as `llm_delta` SSE events alongside the existing polled state/output events. - Incremental SSE parser shared by both providers (CRLF-safe, chunk-split and UTF-8-split safe); request-body building extracted so streaming and non-streaming share multimodal conversion. - Failure taxonomy: mid-stream drop, missing terminal event ([DONE] / message_stop) and an inter-chunk idle timeout (stream_idle_timeout_secs, default 30s) all fail Retryable; Anthropic mid-stream error events map to the existing retryable/permanent split. Failover streams per attempt. - `stream` + `response_schema` falls back to non-streaming (logged): the validate/repair loop needs complete responses. dry_run skips the provider call unchanged. - Engine-only deployments compile unchanged: the bus lives behind the existing engine/api boundary and stays inert without subscribers. 
Tests: mock-SSE provider streams asserting streamed output equals the non-streaming shape (incl. usage and tool calls) for both providers, mid-stream drop / idle timeout / error-event taxonomy, schema fallback, failover streaming, bus pub/sub lifecycle, and API e2e tests proving a connected SSE client receives llm_delta events scoped to its instance. Co-Authored-By: Claude Fable 5 --- orch8-api/src/streaming.rs | 61 ++- orch8-api/tests/streaming.rs | 175 +++++++ orch8-engine/src/handlers/llm/anthropic.rs | 291 ++++++++++- orch8-engine/src/handlers/llm/mod.rs | 572 ++++++++++++++++++++- orch8-engine/src/handlers/llm/openai.rs | 231 ++++++++- orch8-engine/src/handlers/llm/sse.rs | 213 ++++++++ orch8-engine/src/lib.rs | 1 + orch8-engine/src/stream_bus.rs | 174 +++++++ 8 files changed, 1668 insertions(+), 50 deletions(-) create mode 100644 orch8-api/tests/streaming.rs create mode 100644 orch8-engine/src/handlers/llm/sse.rs create mode 100644 orch8-engine/src/stream_bus.rs diff --git a/orch8-api/src/streaming.rs b/orch8-api/src/streaming.rs index 6eff7425..abecfb0f 100644 --- a/orch8-api/src/streaming.rs +++ b/orch8-api/src/streaming.rs @@ -3,10 +3,14 @@ //! `GET /instances/{id}/stream` returns a Server-Sent Events stream that emits: //! - `state` events when instance state changes //! - `output` events when new block outputs appear +//! - `llm_delta` events with incremental text from a streaming `llm_call` +//! step (`stream: true`), forwarded live from the in-process +//! [`orch8_engine::stream_bus`] — not polled from storage //! - `done` event when the instance reaches a terminal state //! -//! The stream polls storage at a configurable interval (default 500ms) and sends -//! keepalive comments every 15s. Closes when the instance reaches a terminal state. +//! State/output changes poll storage at a configurable interval (default +//! 500ms); keepalive comments are sent every 15s. Closes when the instance +//! reaches a terminal state. 
use std::convert::Infallible; use std::time::Duration; @@ -14,8 +18,10 @@ use std::time::Duration; use axum::extract::{Path, Query, State}; use axum::response::sse::{Event, KeepAlive, Sse}; use serde::Deserialize; +use tokio::sync::broadcast; use uuid::Uuid; +use orch8_engine::stream_bus::{stream_bus, StreamEvent}; use orch8_types::ids::InstanceId; use orch8_types::instance::InstanceState; @@ -33,6 +39,19 @@ const fn default_poll_ms() -> u64 { 500 } +/// Await the next live event from the stream-bus subscription. With no +/// subscription left (`None`, after the channel closed) pends forever — the +/// caller guards the select arm with `delta_rx.is_some()`, so this branch is +/// then never polled. +async fn next_delta( + rx: Option<&mut broadcast::Receiver<StreamEvent>>, +) -> Result<StreamEvent, broadcast::error::RecvError> { + match rx { + Some(rx) => rx.recv().await, + None => std::future::pending().await, + } +} + const fn is_terminal(state: InstanceState) -> bool { matches!( state, @@ -47,7 +66,7 @@ const fn is_terminal(state: InstanceState) -> bool { ("poll_ms" = Option<u64>, Query, description = "Poll interval in ms (min 100ms)"), ), responses( - (status = 200, description = "Server-Sent Events stream of instance state/tree/output changes", content_type = "text/event-stream"), + (status = 200, description = "Server-Sent Events stream of instance state/output changes plus live llm_delta events from streaming llm_call steps", content_type = "text/event-stream"), (status = 404, description = "Instance not found"), (status = 503, description = "Too many concurrent streams"), ) @@ -103,6 +122,12 @@ pub(crate) async fn stream_instance( let shutdown = state.shutdown.clone(); let storage = state.storage.clone(); + // Subscribe to the in-process stream bus BEFORE the response starts, so + // llm_delta events published after the client connects are never missed. + // The engine runs in the same process (orch8-server); in API-only + // deployments the channel simply never receives anything.
+ let mut delta_rx = Some(stream_bus().subscribe(instance_id)); + tokio::spawn(async move { // Hold the permit for the lifetime of the stream task; dropped // automatically on return/panic, freeing a slot. @@ -123,6 +148,36 @@ pub(crate) async fn stream_instance( () = shutdown.cancelled() => break, () = tx.closed() => break, _ = ticker.tick() => {} + event = next_delta(delta_rx.as_mut()), if delta_rx.is_some() => { + match event { + Ok(ev) => { + // `StreamEvent` serialization is infallible (plain + // strings + tag); fall back to `{}` defensively. + let payload = serde_json::to_string(&ev) + .unwrap_or_else(|_| "{}".to_string()); + let sse = Event::default().event("llm_delta").data(payload); + if tx.send(Ok(sse)).await.is_err() { + break; + } + } + Err(broadcast::error::RecvError::Lagged(skipped)) => { + // Best-effort live view: a slow client just loses + // deltas; the durable output event carries the + // full text anyway. + tracing::debug!( + skipped, + instance_id = %instance_id.into_uuid(), + "stream: client lagged behind llm_delta broadcast" + ); + } + Err(broadcast::error::RecvError::Closed) => { + delta_rx = None; + } + } + // Deltas don't advance the storage poll; wait for the + // next tick before re-querying. + continue; + } } // If the last iterations failed, wait an extra backoff period diff --git a/orch8-api/tests/streaming.rs b/orch8-api/tests/streaming.rs new file mode 100644 index 00000000..1877735d --- /dev/null +++ b/orch8-api/tests/streaming.rs @@ -0,0 +1,175 @@ +//! E2E tests for `GET /instances/{id}/stream` — specifically the forwarding +//! of live `llm_delta` events from the in-process engine stream bus. +//! +//! The test harness does not run the engine tick loop, so instead of +//! executing a real streaming `llm_call` we publish to the bus directly +//! (exactly what the handler's `DeltaSink` does) and assert a connected SSE +//! client receives the events. The handler-side publication path is covered +//! 
by `orch8-engine`'s mock-SSE provider tests. + +use std::time::Duration; + +use orch8_api::test_harness::spawn_test_server; +use orch8_engine::stream_bus::{stream_bus, StreamEvent}; +use orch8_types::ids::InstanceId; +use reqwest::StatusCode; +use serde_json::json; +use uuid::Uuid; + +fn mk_sequence_body(id: Uuid) -> serde_json::Value { + json!({ + "id": id, + "tenant_id": "t1", + "namespace": "ns1", + "name": "stream-seq", + "version": 1, + "deprecated": false, + "blocks": [ + { + "type": "step", + "id": "s1", + "handler": "noop", + "params": {}, + "cancellable": true + } + ], + "interceptors": null, + "created_at": chrono::Utc::now().to_rfc3339() + }) +} + +/// Create a sequence and an instance of it; returns the instance id. +async fn create_instance(client: &reqwest::Client, base_url: &str) -> Uuid { + let seq_id = Uuid::now_v7(); + let resp = client + .post(format!("{base_url}/sequences")) + .header("X-Tenant-Id", "t1") + .json(&mk_sequence_body(seq_id)) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = client + .post(format!("{base_url}/instances")) + .header("X-Tenant-Id", "t1") + .json(&json!({ + "sequence_id": seq_id, + "tenant_id": "t1", + "namespace": "ns1", + "context": { "data": {}, "config": {}, "audit": [] } + })) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + let created: serde_json::Value = resp.json().await.unwrap(); + created["id"].as_str().unwrap().parse().unwrap() +} + +/// Read SSE chunks into `buf` until it contains `needle` (10s safety cap). 
+async fn read_until(resp: &mut reqwest::Response, buf: &mut String, needle: &str) { + tokio::time::timeout(Duration::from_secs(10), async { + while !buf.contains(needle) { + let chunk = resp + .chunk() + .await + .expect("SSE read failed") + .expect("SSE stream ended before expected event"); + buf.push_str(&String::from_utf8_lossy(&chunk)); + } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {needle:?} in SSE stream; got: {buf}")); +} + +#[tokio::test] +async fn stream_forwards_llm_delta_events_from_bus() { + let srv = spawn_test_server().await; + let client = reqwest::Client::new(); + let inst_id = create_instance(&client, &srv.base_url).await; + + let mut resp = client + .get(format!("{}/instances/{inst_id}/stream", srv.base_url)) + .header("X-Tenant-Id", "t1") + .query(&[("poll_ms", "100")]) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + // The bus subscription is established before the response headers are + // sent, but wait for the first poll-driven `state` event anyway so we + // know the stream task is fully live. + let mut buf = String::new(); + read_until(&mut resp, &mut buf, "event: state").await; + + // Publish exactly what a streaming llm_call's DeltaSink publishes. 
+ let instance_id = InstanceId::from_uuid(inst_id); + for delta in ["Hel", "lo"] { + stream_bus().publish( + instance_id, + StreamEvent::LlmDelta { + block_id: "s1".to_string(), + delta: delta.to_string(), + }, + ); + } + + read_until(&mut resp, &mut buf, "event: llm_delta").await; + read_until( + &mut resp, + &mut buf, + r#"{"type":"llm_delta","block_id":"s1","delta":"Hel"}"#, + ) + .await; + read_until( + &mut resp, + &mut buf, + r#"{"type":"llm_delta","block_id":"s1","delta":"lo"}"#, + ) + .await; +} + +#[tokio::test] +async fn stream_ignores_deltas_for_other_instances() { + let srv = spawn_test_server().await; + let client = reqwest::Client::new(); + let inst_id = create_instance(&client, &srv.base_url).await; + + let mut resp = client + .get(format!("{}/instances/{inst_id}/stream", srv.base_url)) + .header("X-Tenant-Id", "t1") + .query(&[("poll_ms", "100")]) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let mut buf = String::new(); + read_until(&mut resp, &mut buf, "event: state").await; + + // A delta for a DIFFERENT instance must not reach this client; one for + // the watched instance (published after) must. Receiving the marker + // proves the foreign delta was not interleaved before it. 
+ stream_bus().publish( + InstanceId::new(), + StreamEvent::LlmDelta { + block_id: "other".to_string(), + delta: "foreign".to_string(), + }, + ); + stream_bus().publish( + InstanceId::from_uuid(inst_id), + StreamEvent::LlmDelta { + block_id: "s1".to_string(), + delta: "marker".to_string(), + }, + ); + + read_until(&mut resp, &mut buf, r#""delta":"marker""#).await; + assert!( + !buf.contains("foreign"), + "delta for another instance leaked into this stream: {buf}" + ); +} diff --git a/orch8-engine/src/handlers/llm/anthropic.rs b/orch8-engine/src/handlers/llm/anthropic.rs index e0850ea4..73d2118c 100644 --- a/orch8-engine/src/handlers/llm/anthropic.rs +++ b/orch8-engine/src/handlers/llm/anthropic.rs @@ -1,26 +1,20 @@ -use serde_json::{json, Value}; -use tracing::debug; +use std::collections::BTreeMap; + +use serde_json::{json, Map, Value}; +use tracing::{debug, warn}; use orch8_types::error::StepError; use super::common::{ classify_api_error, classify_reqwest_error, extract_system_message, is_json_object_format, - merge_json_response_fields, retryable, + merge_json_response_fields, permanent, retryable, safe_truncate, }; -use super::{anthropic_default_model, http_client}; - -pub(super) async fn call_anthropic( - params: &Value, - api_key: &str, - base_url: &str, -) -> Result { - let url = format!("{base_url}/messages"); - - let model = params - .get("model") - .and_then(Value::as_str) - .unwrap_or(anthropic_default_model()); +use super::sse::{next_chunk, stream_idle_timeout, SseParser}; +use super::{anthropic_default_model, http_client, DeltaSink}; +/// Build the `/messages` request body shared by the streaming and +/// non-streaming paths (so multimodal message conversion behaves identically). 
+fn build_body(params: &Value, model: &str) -> Map { let messages_raw = params.get("messages").cloned().unwrap_or(json!([])); let (system_from_msgs, messages) = extract_system_message(&messages_raw); // Plain-string content is cloned unchanged; normalized image blocks @@ -63,9 +57,29 @@ pub(super) async fn call_anthropic( body.insert(key.into(), val.clone()); } } + body +} + +pub(super) async fn call_anthropic( + params: &Value, + api_key: &str, + base_url: &str, + deltas: Option<&DeltaSink>, +) -> Result { + let url = format!("{base_url}/messages"); + + let model = params + .get("model") + .and_then(Value::as_str) + .unwrap_or(anthropic_default_model()); + + let mut body = build_body(params, model); + if deltas.is_some() { + body.insert("stream".into(), json!(true)); + } let body = Value::Object(body); - debug!(url = %url, model = %model, "llm_call: Anthropic"); + debug!(url = %url, model = %model, streaming = deltas.is_some(), "llm_call: Anthropic"); let resp = http_client() .post(&url) @@ -77,6 +91,10 @@ pub(super) async fn call_anthropic( .await .map_err(|e| classify_reqwest_error(&e))?; + if let Some(sink) = deltas { + return consume_anthropic_stream(resp, params, sink).await; + } + let status = resp.status().as_u16(); let resp_body: Value = resp .json() @@ -103,6 +121,227 @@ pub(super) async fn call_anthropic( Ok(output) } +/// Accumulator for the Anthropic streaming event protocol. Rebuilds the +/// complete (non-streaming-shaped) `/messages` response body so the final +/// output goes through the exact same [`normalize_anthropic_response`] path. +#[derive(Default)] +struct AnthropicStreamAcc { + model: Value, + /// Content blocks keyed by stream `index`. + blocks: BTreeMap, + /// Accumulated `input_json_delta` fragments per `tool_use` block index. + partial_tool_json: BTreeMap, + /// Merged usage: `message_start` provides `input_tokens`, the final + /// `message_delta` overlays the authoritative `output_tokens`. 
+ usage: Map, + stop_reason: Value, + done: bool, +} + +impl AnthropicStreamAcc { + /// Ingest one event payload, publishing text deltas to `sink`. + /// Returns an error for explicit `error` events from the provider. + fn ingest(&mut self, data: &str, sink: &DeltaSink) -> Result<(), StepError> { + let Ok(event) = serde_json::from_str::(data) else { + warn!( + data_preview = %safe_truncate(data, 200), + "llm_call: skipping unparseable Anthropic streaming event" + ); + return Ok(()); + }; + match event.get("type").and_then(Value::as_str).unwrap_or("") { + "message_start" => { + if let Some(message) = event.get("message") { + if let Some(model) = message.get("model").filter(|m| m.is_string()) { + self.model = model.clone(); + } + self.merge_usage(message.get("usage")); + } + } + "content_block_start" => { + if let (Some(index), Some(block)) = ( + event.get("index").and_then(Value::as_u64), + event.get("content_block"), + ) { + self.blocks.insert(index, block.clone()); + } + } + "content_block_delta" => self.ingest_block_delta(&event, sink), + "content_block_stop" => { + if let Some(index) = event.get("index").and_then(Value::as_u64) { + self.finalize_tool_input(index); + } + } + "message_delta" => { + if let Some(reason) = event + .get("delta") + .and_then(|d| d.get("stop_reason")) + .filter(|r| !r.is_null()) + { + self.stop_reason = reason.clone(); + } + self.merge_usage(event.get("usage")); + } + "message_stop" => self.done = true, + "error" => return Err(classify_stream_error(event.get("error"))), + // `ping` and unknown / future event types are ignored. 
+ _ => {} + } + Ok(()) + } + + fn ingest_block_delta(&mut self, event: &Value, sink: &DeltaSink) { + let Some(index) = event.get("index").and_then(Value::as_u64) else { + return; + }; + let Some(delta) = event.get("delta") else { + return; + }; + match delta.get("type").and_then(Value::as_str).unwrap_or("") { + "text_delta" => { + if let Some(text) = delta.get("text").and_then(Value::as_str) { + if let Some(Value::String(existing)) = + self.blocks.get_mut(&index).and_then(|b| b.get_mut("text")) + { + existing.push_str(text); + } + if !text.is_empty() { + sink.publish(text); + } + } + } + "input_json_delta" => { + if let Some(fragment) = delta.get("partial_json").and_then(Value::as_str) { + self.partial_tool_json + .entry(index) + .or_default() + .push_str(fragment); + } + } + _ => {} + } + } + + /// Parse the accumulated `input_json_delta` fragments of a `tool_use` + /// block into its `input` field. + fn finalize_tool_input(&mut self, index: u64) { + let Some(fragments) = self.partial_tool_json.remove(&index) else { + return; + }; + if fragments.is_empty() { + return; + } + match serde_json::from_str::(&fragments) { + Ok(input) => { + if let Some(block) = self.blocks.get_mut(&index) { + block["input"] = input; + } + } + Err(e) => warn!( + error = %e, + "llm_call: accumulated tool_use input is not valid JSON" + ), + } + } + + fn merge_usage(&mut self, usage: Option<&Value>) { + if let Some(map) = usage.and_then(Value::as_object) { + for (k, v) in map { + self.usage.insert(k.clone(), v.clone()); + } + } + } + + /// Rebuild the non-streaming response body and normalize it. 
+ fn into_output(self, params: &Value) -> Value { + let content: Vec = self.blocks.into_values().collect(); + let rebuilt = json!({ + "content": content, + "model": self.model, + "stop_reason": self.stop_reason, + "usage": Value::Object(self.usage), + }); + let mut output = normalize_anthropic_response(&rebuilt); + if is_json_object_format(params) { + if let Some(content_owned) = output + .get("message") + .and_then(|m| m.get("content")) + .and_then(Value::as_str) + .map(String::from) + { + merge_json_response_fields(&content_owned, &mut output); + } + } + output + } +} + +/// Map a mid-stream `error` event to the retry taxonomy: capacity/server +/// conditions are retryable (and eligible for provider failover); request, +/// auth and permission errors are permanent. +fn classify_stream_error(error: Option<&Value>) -> StepError { + let error_type = error + .and_then(|e| e.get("type")) + .and_then(Value::as_str) + .unwrap_or("unknown"); + let message = error + .and_then(|e| e.get("message")) + .and_then(Value::as_str) + .unwrap_or("unknown error"); + match error_type { + "invalid_request_error" + | "authentication_error" + | "permission_error" + | "not_found_error" + | "request_too_large" => permanent(format!("stream error ({error_type}): {message}")), + // overloaded_error, api_error, rate_limit_error, timeout_error, … + _ => retryable(format!("stream error ({error_type}): {message}")), + } +} + +/// Consume an Anthropic SSE event stream, publishing text deltas and +/// rebuilding the full response. +/// +/// Termination contract: the provider must send `message_stop`. A stream +/// that ends before that is incomplete and fails **retryable**; a chunk gap +/// longer than the idle timeout fails retryable via [`next_chunk`]. +async fn consume_anthropic_stream( + mut resp: reqwest::Response, + params: &Value, + sink: &DeltaSink, +) -> Result { + let status = resp.status().as_u16(); + if status >= 400 { + // Error responses are plain JSON, not SSE. 
+ let resp_body: Value = resp + .json() + .await + .map_err(|e| retryable(format!("response parse error: {e}")))?; + return Err(classify_api_error(status, &resp_body)); + } + + let idle_timeout = stream_idle_timeout(params); + let mut parser = SseParser::default(); + let mut acc = AnthropicStreamAcc::default(); + + while let Some(chunk) = next_chunk(&mut resp, idle_timeout).await? { + for event in parser.push(&chunk) { + acc.ingest(&event.data, sink)?; + } + if acc.done { + break; + } + } + + if !acc.done { + return Err(retryable( + "provider stream ended before message_stop — response is incomplete".to_string(), + )); + } + + Ok(acc.into_output(params)) +} + fn normalize_anthropic_response(resp_body: &Value) -> Value { let content = resp_body.get("content").cloned().unwrap_or_default(); @@ -196,4 +435,22 @@ mod tests { let output = normalize_anthropic_response(&resp); assert_eq!(output["provider"], "anthropic"); } + + #[test] + fn classify_stream_error_taxonomy() { + let overloaded = json!({"type": "overloaded_error", "message": "busy"}); + assert!(matches!( + classify_stream_error(Some(&overloaded)), + StepError::Retryable { .. } + )); + let invalid = json!({"type": "invalid_request_error", "message": "bad"}); + assert!(matches!( + classify_stream_error(Some(&invalid)), + StepError::Permanent { .. } + )); + assert!(matches!( + classify_stream_error(None), + StepError::Retryable { .. } + )); + } } diff --git a/orch8-engine/src/handlers/llm/mod.rs b/orch8-engine/src/handlers/llm/mod.rs index 859ee7a9..a7ddc0fa 100644 --- a/orch8-engine/src/handlers/llm/mod.rs +++ b/orch8-engine/src/handlers/llm/mod.rs @@ -25,6 +25,31 @@ //! | `response_schema` | object | — | JSON Schema the response must satisfy (validated, auto-repaired) | //! | `max_repair_attempts` | number | `2` | Schema-repair re-calls per provider attempt (hard cap 5) | //! | `max_image_bytes` | number | 20 MiB | Per-image size cap, pre-encoding (can only lower the default) | +//! 
| `stream` | bool | `false` | Consume the provider's SSE stream (see Streaming) | +//! | `stream_idle_timeout_secs` | number | `30` | Max gap between streamed chunks before failing retryable | +//! +//! ## Streaming +//! +//! With `stream: true` the provider call uses the streaming wire protocol +//! (`OpenAI` `/chat/completions` SSE with `stream_options.include_usage`; +//! Anthropic `/messages` event stream). Incremental text deltas are published +//! to the in-process [`crate::stream_bus`] as `llm_delta` events — clients +//! watching `GET /instances/{id}/stream` receive them live. The step's +//! **durable output is unchanged**: deltas are accumulated and the completed +//! output has exactly the non-streaming shape (message, `finish_reason`, +//! `usage`, tool calls), so downstream blocks are unaffected. +//! +//! Failure taxonomy: a mid-stream connection drop, a stream that ends before +//! the provider's terminal event (`[DONE]` / `message_stop`), or a chunk gap +//! exceeding `stream_idle_timeout_secs` all fail **retryable** (and are +//! eligible for provider failover). Provider `error` events map to the usual +//! retryable/permanent split. +//! +//! Interactions: `stream` combined with `response_schema` falls back to +//! non-streaming (logged) — the validate/repair loop requires complete +//! responses, and the durable output is identical either way. `dry_run` +//! skips the provider call entirely, exactly as without streaming. Failover +//! and multimodal content work unchanged (request building is shared). //! //! ## Multimodal content //! 
@@ -82,6 +107,7 @@ pub(crate) mod common; mod multimodal; mod openai; mod schema; +mod sse; use std::sync::OnceLock; use std::time::Duration; @@ -144,6 +170,33 @@ pub(crate) fn http_client() -> &'static reqwest::Client { }) } +/// Live-delta publisher for a streaming `llm_call` step: forwards incremental +/// text fragments to the per-instance [`crate::stream_bus`] channel so clients +/// watching the instance's SSE stream see tokens as they arrive. Publishing is +/// best-effort and a no-op when nobody is subscribed; the durable step output +/// is always the full accumulated response regardless. +pub(crate) struct DeltaSink { + instance_id: orch8_types::ids::InstanceId, + block_id: String, +} + +impl DeltaSink { + /// Publish one text delta for this step. + fn publish(&self, delta: &str) { + let bus = crate::stream_bus::stream_bus(); + if !bus.has_subscribers(self.instance_id) { + return; // skip the event allocation when nobody is watching + } + bus.publish( + self.instance_id, + crate::stream_bus::StreamEvent::LlmDelta { + block_id: self.block_id.clone(), + delta: delta.to_string(), + }, + ); + } +} + /// Main handler: routes to the correct provider API. /// /// If the params contain a `providers` array, iterates through each provider @@ -157,6 +210,31 @@ pub async fn handle_llm_call(mut ctx: StepContext) -> Result { // error without burning tokens. let response_schema = schema::compile_response_schema(&ctx.params)?; + // Streaming applies only without `response_schema`: the validate/repair + // loop needs complete responses, so `stream` + `response_schema` falls + // back to a regular (non-streaming) call with identical durable output. 
+ let stream_requested = ctx + .params + .get("stream") + .and_then(Value::as_bool) + .unwrap_or(false); + let delta_sink = if stream_requested && !dry { + if response_schema.is_some() { + warn!( + block_id = %ctx.block_id.as_str(), + "llm_call: stream=true combined with response_schema — falling back to non-streaming" + ); + None + } else { + Some(DeltaSink { + instance_id: ctx.instance_id, + block_id: ctx.block_id.as_str().to_string(), + }) + } + } else { + None + }; + if let Some(providers) = ctx.params.get("providers").and_then(Value::as_array) { if dry { // Validate before skipping: an empty providers array is a config @@ -170,7 +248,13 @@ pub async fn handle_llm_call(mut ctx: StepContext) -> Result { // Resolve artifact-backed image blocks ONCE, before the failover loop // and schema paths, so every provider attempt reuses the fetched bytes. multimodal::resolve_message_images(&ctx.storage, ctx.instance_id, &mut ctx.params).await?; - return handle_llm_call_failover(&ctx.params, &providers, response_schema.as_ref()).await; + return handle_llm_call_failover( + &ctx.params, + &providers, + response_schema.as_ref(), + delta_sink.as_ref(), + ) + .await; } let provider = ctx @@ -204,7 +288,9 @@ pub async fn handle_llm_call(mut ctx: StepContext) -> Result { .await .map_err(SchemaCallFailure::into_permanent)? } - None => dispatch_provider(&ctx.params, &api_key, &base, &provider).await?, + None => { + dispatch_provider(&ctx.params, &api_key, &base, &provider, delta_sink.as_ref()).await? + } }; emit_gen_ai_telemetry(&ctx.params, &provider, &out); // Capture token usage for cost aggregation (best-effort — never fails the call). @@ -251,17 +337,20 @@ fn emit_gen_ai_telemetry(params: &Value, provider: &str, out: &Value) { ); } -/// Route a single call to the correct provider API. +/// Route a single call to the correct provider API. 
`deltas: Some(_)` selects +/// the streaming wire protocol (SSE consumption + delta publication); the +/// returned output has the same shape either way. async fn dispatch_provider( params: &Value, api_key: &str, base: &str, provider: &str, + deltas: Option<&DeltaSink>, ) -> Result { if provider == "anthropic" { - anthropic::call_anthropic(params, api_key, base).await + anthropic::call_anthropic(params, api_key, base, deltas).await } else { - openai::call_openai_compat(params, api_key, base, provider).await + openai::call_openai_compat(params, api_key, base, provider, deltas).await } } @@ -334,7 +423,9 @@ async fn call_provider_with_schema( let mut last_response = String::new(); for attempt in 0..=max { - let mut out = dispatch_provider(&work, api_key, base, provider) + // Always non-streaming: schema validation + repair needs the complete + // response (the `stream` fallback is decided in `handle_llm_call`). + let mut out = dispatch_provider(&work, api_key, base, provider, None) .await .map_err(SchemaCallFailure::Provider)?; let (input, output) = usage_tokens(&out); @@ -448,6 +539,7 @@ async fn handle_llm_call_failover( params: &Value, providers: &[Value], response_schema: Option<&schema::CompiledSchema>, + delta_sink: Option<&DeltaSink>, ) -> Result { if providers.is_empty() { return Err(permanent("providers array is empty".to_string())); @@ -459,12 +551,12 @@ async fn handle_llm_call_failover( .map_or(DEFAULT_TOTAL_TIMEOUT, Duration::from_secs); if total_timeout.is_zero() { - return failover_inner(params, providers, response_schema).await; + return failover_inner(params, providers, response_schema, delta_sink).await; } match tokio::time::timeout( total_timeout, - failover_inner(params, providers, response_schema), + failover_inner(params, providers, response_schema, delta_sink), ) .await { @@ -479,6 +571,7 @@ async fn failover_inner( params: &Value, providers: &[Value], response_schema: Option<&schema::CompiledSchema>, + delta_sink: Option<&DeltaSink>, ) -> 
Result<Value, StepError> { let per_attempt_timeout = params .get("per_provider_timeout_secs") @@ -523,7 +616,9 @@ async fn failover_inner( .await .map_err(SchemaCallFailure::into_retryable_for_failover) } - None => dispatch_provider(&merged, &api_key, &base, provider_name).await, + None => { + dispatch_provider(&merged, &api_key, &base, provider_name, delta_sink).await + } } }; @@ -601,7 +696,7 @@ mod tests { fn empty_providers_returns_error() { let result = tokio::runtime::Runtime::new() .unwrap() - .block_on(handle_llm_call_failover(&json!({}), &[], None)); + .block_on(handle_llm_call_failover(&json!({}), &[], None, None)); assert!(result.is_err()); } @@ -695,7 +790,7 @@ mod tests { fn cumulative_timeout_returns_error_on_empty_providers_before_timeout() { let result = tokio::runtime::Runtime::new() .unwrap() - .block_on(handle_llm_call_failover(&json!({}), &[], None)); + .block_on(handle_llm_call_failover(&json!({}), &[], None, None)); assert!(matches!(result, Err(StepError::Permanent { .. }))); } @@ -704,7 +799,7 @@ mod tests { let params = json!({"total_timeout_secs": 0}); let result = tokio::runtime::Runtime::new() .unwrap() - .block_on(handle_llm_call_failover(&params, &[], None)); + .block_on(handle_llm_call_failover(&params, &[], None, None)); assert!(matches!(result, Err(StepError::Permanent { .. }))); } @@ -1381,4 +1476,457 @@ mod tests { "no validation in dry-run" ); } + + // --- streaming (mock SSE servers) --------------------------------------- + // + // Same dep-free TCP pattern as `start_openai_mock`, but the response is a + // Server-Sent Events body (no Content-Length; body ends when the + // connection closes). `hold_open` keeps the socket open after the body to + // exercise the inter-chunk idle timeout.
+ + async fn start_sse_mock( + sse_body: String, + hold_open: bool, + ) -> (String, Arc<tokio::sync::Mutex<Vec<Value>>>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let base = format!("http://127.0.0.1:{port}"); + let requests = Arc::new(tokio::sync::Mutex::new(Vec::<Value>::new())); + let requests_srv = Arc::clone(&requests); + + tokio::spawn(async move { + loop { + let Ok((mut stream, _)) = listener.accept().await else { + break; + }; + let raw = read_http_request(&mut stream).await; + if let Some(pos) = raw.windows(4).position(|w| w == b"\r\n\r\n") { + if let Ok(v) = serde_json::from_slice::<Value>(&raw[pos + 4..]) { + requests_srv.lock().await.push(v); + } + } + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\n\ + Cache-Control: no-cache\r\nConnection: close\r\n\r\n{sse_body}" + ); + let _ = stream.write_all(resp.as_bytes()).await; + if hold_open { + // Leave the connection open with no further data so the + // client's idle timeout (not EOF) decides the outcome. + tokio::time::sleep(Duration::from_secs(60)).await; + } + let _ = stream.shutdown().await; + } + }); + + super::super::builtin::mark_url_safe_for_test(&base).await; + (base, requests) + } + + /// Render JSON chunks as `OpenAI`-style `data:` SSE events. + fn openai_sse(chunks: &[Value], with_done: bool) -> String { + use std::fmt::Write as _; + let mut body = String::new(); + for chunk in chunks { + let _ = write!(body, "data: {chunk}\n\n"); + } + if with_done { + body.push_str("data: [DONE]\n\n"); + } + body + } + + /// Render `(event, data)` pairs as Anthropic-style named SSE events. + fn anthropic_sse(events: &[(&str, Value)]) -> String { + use std::fmt::Write as _; + let mut body = String::new(); + for (name, data) in events { + let _ = write!(body, "event: {name}\ndata: {data}\n\n"); + } + body + } + + /// A complete `OpenAI` SSE stream producing "Hello world" with usage 10/5.
+ fn openai_hello_stream() -> String { + openai_sse( + &[ + json!({"id": "c1", "model": "gpt-test", "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": ""}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {"content": "Hello "}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {"content": "world"}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}), + json!({"choices": [], "usage": {"prompt_tokens": 10, "completion_tokens": 5}}), + ], + true, + ) + } + + fn stream_params(base: &str) -> Value { + json!({ + "provider": "openai", + "model": "gpt-test", + "api_key": "k", + "base_url": base, + "messages": [{"role": "user", "content": "hello"}], + "stream": true, + }) + } + + #[tokio::test] + async fn openai_streaming_output_matches_non_streaming_shape() { + // Non-streaming reference call. + let (base_ref, _) = start_openai_mock(vec![openai_resp("Hello world", 10, 5)]).await; + let mut ref_params = stream_params(&base_ref); + ref_params["stream"] = json!(false); + let reference = handle_llm_call(test_ctx(ref_params).await).await.unwrap(); + + // Streaming call over the canned SSE stream. 
+ let (base, requests) = start_sse_mock(openai_hello_stream(), false).await; + let out = handle_llm_call(test_ctx(stream_params(&base)).await) + .await + .unwrap(); + + assert_eq!( + out, reference, + "accumulated streaming output must equal the non-streaming shape" + ); + assert_eq!(out["message"]["content"], "Hello world"); + assert_eq!(out["usage"]["prompt_tokens"], 10); + assert_eq!(out["usage"]["completion_tokens"], 5); + assert_eq!(out["finish_reason"], "stop"); + + let reqs = requests.lock().await; + assert_eq!(reqs.len(), 1); + assert_eq!(reqs[0]["stream"], true, "provider asked to stream"); + assert_eq!( + reqs[0]["stream_options"]["include_usage"], true, + "usage requested so the streamed output keeps token accounting" + ); + } + + #[tokio::test] + async fn openai_streaming_accumulates_tool_call_deltas() { + let body = openai_sse( + &[ + json!({"model": "gpt-test", "choices": [{"index": 0, "delta": {"role": "assistant", "tool_calls": [ + {"index": 0, "id": "call_1", "type": "function", "function": {"name": "search", "arguments": ""}} + ]}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": "{\"q\":"}} + ]}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": "\"rust\"}"}} + ]}, "finish_reason": null}]}), + json!({"choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}]}), + json!({"choices": [], "usage": {"prompt_tokens": 7, "completion_tokens": 3}}), + ], + true, + ); + let (base, _requests) = start_sse_mock(body, false).await; + let out = handle_llm_call(test_ctx(stream_params(&base)).await) + .await + .unwrap(); + + assert_eq!(out["finish_reason"], "tool_calls"); + assert_eq!( + out["message"]["content"], + Value::Null, + "tool-call-only response keeps content null like non-streaming" + ); + let calls = out["message"]["tool_calls"].as_array().unwrap(); + assert_eq!(calls.len(), 1); + 
assert_eq!(calls[0]["id"], "call_1"); + assert_eq!(calls[0]["type"], "function"); + assert_eq!(calls[0]["function"]["name"], "search"); + assert_eq!(calls[0]["function"]["arguments"], "{\"q\":\"rust\"}"); + assert_eq!(out["usage"]["prompt_tokens"], 7); + } + + #[tokio::test] + async fn openai_stream_dropped_before_done_is_retryable() { + // Chunks but no `[DONE]`: the connection closes mid-response. + let body = openai_sse( + &[json!({"model": "gpt-test", "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "Hel"}, "finish_reason": null}]})], + false, + ); + let (base, _) = start_sse_mock(body, false).await; + let err = handle_llm_call(test_ctx(stream_params(&base)).await) + .await + .expect_err("incomplete stream must fail"); + match err { + StepError::Retryable { message, .. } => { + assert!(message.contains("[DONE]"), "{message}"); + } + other @ StepError::Permanent { .. } => panic!("expected Retryable, got {other}"), + } + } + + #[tokio::test] + async fn openai_stream_idle_timeout_is_retryable() { + // One chunk, then the socket stays open and silent. + let body = openai_sse( + &[json!({"model": "gpt-test", "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "He"}, "finish_reason": null}]})], + false, + ); + let (base, _) = start_sse_mock(body, true).await; + let mut params = stream_params(&base); + params["stream_idle_timeout_secs"] = json!(1); + let err = handle_llm_call(test_ctx(params).await) + .await + .expect_err("stalled stream must fail"); + match err { + StepError::Retryable { message, .. } => { + assert!(message.contains("stalled"), "{message}"); + } + other @ StepError::Permanent { .. } => panic!("expected Retryable, got {other}"), + } + } + + /// A complete Anthropic event stream producing "a cat" with usage 4/2. 
fn anthropic_cat_stream() -> String {
    anthropic_sse(&[
        anthropic_message_start(4),
        (
            "content_block_start",
            json!({"type": "content_block_start", "index": 0,
                   "content_block": {"type": "text", "text": ""}}),
        ),
        ("ping", json!({"type": "ping"})),
        (
            "content_block_delta",
            json!({"type": "content_block_delta", "index": 0,
                   "delta": {"type": "text_delta", "text": "a "}}),
        ),
        (
            "content_block_delta",
            json!({"type": "content_block_delta", "index": 0,
                   "delta": {"type": "text_delta", "text": "cat"}}),
        ),
        (
            "content_block_stop",
            json!({"type": "content_block_stop", "index": 0}),
        ),
        (
            "message_delta",
            json!({"type": "message_delta",
                   "delta": {"stop_reason": "end_turn"}, "usage": {"output_tokens": 2}}),
        ),
        ("message_stop", json!({"type": "message_stop"})),
    ])
}

/// The `message_start` event that opens every canned Anthropic stream in
/// these tests; only the reported `input_tokens` differs between them.
fn anthropic_message_start(input_tokens: u64) -> (&'static str, Value) {
    (
        "message_start",
        json!({"type": "message_start", "message": {
            "id": "m1", "model": "claude-test", "role": "assistant",
            "usage": {"input_tokens": input_tokens, "output_tokens": 1}}}),
    )
}

/// Minimal `stream: true` params for an Anthropic call against the mock at `base`.
fn anthropic_stream_params(base: &str) -> Value {
    json!({
        "provider": "anthropic",
        "model": "claude-test",
        "api_key": "k",
        "base_url": base,
        "messages": [{"role": "user", "content": "describe"}],
        "stream": true,
    })
}

/// Assert that `err` is `StepError::Retryable` and that its message mentions
/// `needle`; any other variant fails the test.
fn assert_anthropic_retryable(err: StepError, needle: &str) {
    match err {
        StepError::Retryable { message, .. } => {
            assert!(message.contains(needle), "{message}");
        }
        other @ StepError::Permanent { .. } => panic!("expected Retryable, got {other}"),
    }
}

#[tokio::test]
async fn anthropic_streaming_output_matches_non_streaming_shape() {
    // Non-streaming reference call against the canned full response.
    let full_response = json!({
        "content": [{"type": "text", "text": "a cat"}],
        "model": "claude-test",
        "stop_reason": "end_turn",
        "usage": {"input_tokens": 4, "output_tokens": 2},
    });
    let (reference_base, _) = start_openai_mock(vec![full_response]).await;
    let mut reference_params = anthropic_stream_params(&reference_base);
    reference_params["stream"] = json!(false);
    let reference = handle_llm_call(test_ctx(reference_params).await)
        .await
        .unwrap();

    let (stream_base, captured) = start_sse_mock(anthropic_cat_stream(), false).await;
    let streamed = handle_llm_call(test_ctx(anthropic_stream_params(&stream_base)).await)
        .await
        .unwrap();

    assert_eq!(
        streamed, reference,
        "accumulated streaming output must equal the non-streaming shape"
    );
    assert_eq!(streamed["message"]["content"], "a cat");
    assert_eq!(
        streamed["usage"],
        json!({"input_tokens": 4, "output_tokens": 2})
    );
    assert_eq!(streamed["finish_reason"], "end_turn");
    assert_eq!(streamed["model"], "claude-test");

    let captured = captured.lock().await;
    assert_eq!(captured.len(), 1);
    assert_eq!(captured[0]["stream"], true);
}

#[tokio::test]
async fn anthropic_streaming_accumulates_tool_use_input_json() {
    // The tool input arrives as two `input_json_delta` fragments that must be
    // concatenated into one arguments string.
    let body = anthropic_sse(&[
        anthropic_message_start(9),
        (
            "content_block_start",
            json!({"type": "content_block_start", "index": 0,
                   "content_block": {"type": "tool_use", "id": "tc_1", "name": "search", "input": {}}}),
        ),
        (
            "content_block_delta",
            json!({"type": "content_block_delta", "index": 0,
                   "delta": {"type": "input_json_delta", "partial_json": "{\"q\":"}}),
        ),
        (
            "content_block_delta",
            json!({"type": "content_block_delta", "index": 0,
                   "delta": {"type": "input_json_delta", "partial_json": "\"rust\"}"}}),
        ),
        (
            "content_block_stop",
            json!({"type": "content_block_stop", "index": 0}),
        ),
        (
            "message_delta",
            json!({"type": "message_delta",
                   "delta": {"stop_reason": "tool_use"}, "usage": {"output_tokens": 6}}),
        ),
        ("message_stop", json!({"type": "message_stop"})),
    ]);
    let (base, _) = start_sse_mock(body, false).await;
    let out = handle_llm_call(test_ctx(anthropic_stream_params(&base)).await)
        .await
        .unwrap();

    assert_eq!(out["finish_reason"], "tool_use");
    let tool_calls = out["message"]["tool_calls"].as_array().unwrap();
    assert_eq!(tool_calls.len(), 1);
    let call = &tool_calls[0];
    assert_eq!(call["id"], "tc_1");
    assert_eq!(call["function"]["name"], "search");
    assert_eq!(call["function"]["arguments"], "{\"q\":\"rust\"}");
    assert_eq!(out["usage"]["output_tokens"], 6);
}

#[tokio::test]
async fn anthropic_stream_overloaded_error_event_is_retryable() {
    let body = anthropic_sse(&[
        anthropic_message_start(1),
        (
            "error",
            json!({"type": "error",
                   "error": {"type": "overloaded_error", "message": "Overloaded"}}),
        ),
    ]);
    let (base, _) = start_sse_mock(body, false).await;
    let err = handle_llm_call(test_ctx(anthropic_stream_params(&base)).await)
        .await
        .expect_err("error event must fail the call");
    assert_anthropic_retryable(err, "overloaded_error");
}

#[tokio::test]
async fn anthropic_stream_dropped_before_message_stop_is_retryable() {
    // The stream opens a text block and then the connection closes: no
    // `message_stop` is ever sent.
    let body = anthropic_sse(&[
        anthropic_message_start(1),
        (
            "content_block_start",
            json!({"type": "content_block_start", "index": 0,
                   "content_block": {"type": "text", "text": ""}}),
        ),
    ]);
    let (base, _) = start_sse_mock(body, false).await;
    let err = handle_llm_call(test_ctx(anthropic_stream_params(&base)).await)
        .await
        .expect_err("incomplete stream must fail");
    assert_anthropic_retryable(err, "message_stop");
}

#[tokio::test]
async fn stream_with_response_schema_falls_back_to_non_streaming() {
    // The mock serves a plain (non-SSE) JSON response — proving the call
    // went out as a regular completion despite `stream: true`.
    let (base, captured) =
        start_openai_mock(vec![openai_resp(r#"{"name": "ok"}"#, 10, 5)]).await;
    let mut params = stream_params(&base);
    params["response_schema"] = name_schema();
    let out = handle_llm_call(test_ctx(params).await).await.unwrap();
    assert_eq!(out["schema_valid"], true);
    assert_eq!(out["name"], "ok");

    let captured = captured.lock().await;
    assert_eq!(captured.len(), 1);
    assert!(
        captured[0].get("stream").is_none(),
        "request body must not ask the provider to stream"
    );
}

#[tokio::test]
async fn streaming_publishes_llm_deltas_to_bus() {
    use crate::stream_bus::{stream_bus, StreamEvent};

    let (base, _) = start_sse_mock(openai_hello_stream(), false).await;
    let ctx = test_ctx(stream_params(&base)).await;
    // Subscribe before the call so the bus has a channel to publish into.
    let mut rx = stream_bus().subscribe(ctx.instance_id);

    let out = handle_llm_call(ctx).await.unwrap();
    assert_eq!(out["message"]["content"], "Hello world");

    // The call has finished, so everything published is already buffered.
    let mut deltas = Vec::new();
    while let Ok(event) = rx.try_recv() {
        let StreamEvent::LlmDelta { block_id, delta } = event;
        assert_eq!(block_id, "b");
        deltas.push(delta);
    }
    assert_eq!(deltas, vec!["Hello ".to_string(), "world".to_string()]);
}

#[tokio::test]
async fn failover_streams_from_second_provider_after_first_fails() {
    // Provider A's key env var is unset → skipped; provider B streams.
    let (base_b, requests_b) = start_sse_mock(openai_hello_stream(), false).await;
    let params = json!({
        "messages": [{"role": "user", "content": "hello"}],
        "stream": true,
        "providers": [
            {"provider": "openai", "api_key_env": "LLM_TEST_UNSET_STREAM_KEY", "model": "m", "base_url": base_b},
            {"provider": "openai", "api_key": "k", "model": "gpt-test", "base_url": base_b},
        ],
    });
    let out = handle_llm_call(test_ctx(params).await).await.unwrap();
    assert_eq!(out["tried"], json!(["openai", "openai"]));
    assert_eq!(out["message"]["content"], "Hello world");
    assert_eq!(out["usage"]["completion_tokens"], 5);
    assert_eq!(requests_b.lock().await.len(), 1);
}
} // end of `mod tests` (orch8-engine/src/handlers/llm/mod.rs)

// ---------------------------------------------------------------------------
// diff --git a/orch8-engine/src/handlers/llm/openai.rs b/orch8-engine/src/handlers/llm/openai.rs
// index de6113ab..920ccbe5 100644 — hunk @@ -1,27 +1,20 @@ (post-patch form)
// ---------------------------------------------------------------------------
use std::collections::BTreeMap;

use serde_json::{json, Map, Value};
use tracing::{debug, warn};

use orch8_types::error::StepError;

use super::common::{
    classify_api_error, classify_reqwest_error, is_json_object_format, merge_json_response_fields,
    retryable, safe_truncate,
};
use super::sse::{next_chunk, stream_idle_timeout, SseParser};
use super::{http_client, openai_default_model, DeltaSink};

/// Build the `/chat/completions` request body shared by the streaming and
/// non-streaming paths (so multimodal message conversion behaves identically).
+fn build_body(params: &Value, model: &str) -> Map { let messages = { let mut msgs = Vec::new(); if let Some(sys) = params.get("system").and_then(Value::as_str) { @@ -56,9 +49,33 @@ pub(super) async fn call_openai_compat( body.insert(key.into(), val.clone()); } } + body +} + +pub(super) async fn call_openai_compat( + params: &Value, + api_key: &str, + base_url: &str, + provider: &str, + deltas: Option<&DeltaSink>, +) -> Result { + let url = format!("{base_url}/chat/completions"); + + let model = params + .get("model") + .and_then(Value::as_str) + .unwrap_or(openai_default_model()); + + let mut body = build_body(params, model); + if deltas.is_some() { + body.insert("stream".into(), json!(true)); + // Without this the final usage chunk is omitted and the streamed + // output would lose token accounting relative to non-streaming. + body.insert("stream_options".into(), json!({"include_usage": true})); + } let body = Value::Object(body); - debug!(url = %url, model = %model, provider = %provider, "llm_call: OpenAI-compatible"); + debug!(url = %url, model = %model, provider = %provider, streaming = deltas.is_some(), "llm_call: OpenAI-compatible"); let resp = http_client() .post(&url) @@ -69,6 +86,10 @@ pub(super) async fn call_openai_compat( .await .map_err(|e| classify_reqwest_error(&e))?; + if let Some(sink) = deltas { + return consume_openai_stream(resp, params, provider, sink).await; + } + let status = resp.status().as_u16(); let resp_body: Value = resp .json() @@ -105,3 +126,177 @@ pub(super) async fn call_openai_compat( Ok(output) } + +/// Partially-accumulated tool call, keyed by the chunk `index` field. +#[derive(Default)] +struct ToolCallAcc { + id: Option, + name: Option, + arguments: String, +} + +/// Accumulator for `OpenAI` streaming chunks. Produces an output identical in +/// shape to the non-streaming path once the stream terminates with `[DONE]`. 
+#[derive(Default)] +struct OpenAiStreamAcc { + model: Value, + role: Option, + content: String, + saw_content: bool, + tool_calls: BTreeMap, + finish_reason: Value, + usage: Value, + done: bool, +} + +impl OpenAiStreamAcc { + /// Ingest one `data:` payload, publishing text deltas to `sink`. + fn ingest(&mut self, data: &str, sink: &DeltaSink) { + if data.trim() == "[DONE]" { + self.done = true; + return; + } + let Ok(chunk) = serde_json::from_str::(data) else { + warn!( + data_preview = %safe_truncate(data, 200), + "llm_call: skipping unparseable streaming chunk" + ); + return; + }; + if self.model.is_null() { + if let Some(model) = chunk.get("model").filter(|m| m.is_string()) { + self.model = model.clone(); + } + } + // The final chunk (stream_options.include_usage) carries usage with + // an empty choices array. + if let Some(usage) = chunk.get("usage").filter(|u| u.is_object()) { + self.usage = usage.clone(); + } + let Some(choice) = chunk.get("choices").and_then(|c| c.get(0)) else { + return; + }; + if let Some(fr) = choice.get("finish_reason").filter(|fr| !fr.is_null()) { + self.finish_reason = fr.clone(); + } + let Some(delta) = choice.get("delta") else { + return; + }; + if let Some(role) = delta.get("role").and_then(Value::as_str) { + self.role.get_or_insert_with(|| role.to_string()); + } + if let Some(text) = delta.get("content").and_then(Value::as_str) { + self.saw_content = true; + if !text.is_empty() { + self.content.push_str(text); + sink.publish(text); + } + } + if let Some(calls) = delta.get("tool_calls").and_then(Value::as_array) { + for call in calls { + let index = call.get("index").and_then(Value::as_u64).unwrap_or(0); + let acc = self.tool_calls.entry(index).or_default(); + if let Some(id) = call.get("id").and_then(Value::as_str) { + acc.id.get_or_insert_with(|| id.to_string()); + } + if let Some(function) = call.get("function") { + if let Some(name) = function.get("name").and_then(Value::as_str) { + acc.name.get_or_insert_with(|| 
name.to_string()); + } + if let Some(args) = function.get("arguments").and_then(Value::as_str) { + acc.arguments.push_str(args); + } + } + } + } + } + + /// Assemble the final output in the exact shape of the non-streaming path. + fn into_output(self, provider: &str, params: &Value) -> Value { + let mut message = serde_json::Map::new(); + message.insert( + "role".into(), + json!(self.role.as_deref().unwrap_or("assistant")), + ); + // Tool-call-only responses report `content: null` (matching the + // non-streaming response shape); otherwise the accumulated text. + let content = if !self.saw_content && !self.tool_calls.is_empty() { + Value::Null + } else { + json!(self.content) + }; + message.insert("content".into(), content); + if !self.tool_calls.is_empty() { + let calls: Vec = self + .tool_calls + .into_values() + .map(|tc| { + json!({ + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments}, + }) + }) + .collect(); + message.insert("tool_calls".into(), json!(calls)); + } + + let mut output = json!({ + "provider": provider, + "model": self.model, + "message": Value::Object(message), + "finish_reason": self.finish_reason, + "usage": self.usage, + }); + + if is_json_object_format(params) && self.saw_content { + merge_json_response_fields(&self.content, &mut output); + } + output + } +} + +/// Consume an OpenAI-compatible SSE stream, publishing text deltas and +/// accumulating the full response. +/// +/// Termination contract: the provider must send `data: [DONE]`. A stream +/// that ends (EOF / connection drop) before that is incomplete and fails +/// **retryable**; a chunk gap longer than the idle timeout fails retryable +/// via [`next_chunk`]. +async fn consume_openai_stream( + mut resp: reqwest::Response, + params: &Value, + provider: &str, + sink: &DeltaSink, +) -> Result { + let status = resp.status().as_u16(); + if status >= 400 { + // Error responses are plain JSON, not SSE. 
+ let resp_body: Value = resp + .json() + .await + .map_err(|e| retryable(format!("response parse error: {e}")))?; + return Err(classify_api_error(status, &resp_body)); + } + + let idle_timeout = stream_idle_timeout(params); + let mut parser = SseParser::default(); + let mut acc = OpenAiStreamAcc::default(); + + while let Some(chunk) = next_chunk(&mut resp, idle_timeout).await? { + for event in parser.push(&chunk) { + acc.ingest(&event.data, sink); + } + if acc.done { + break; + } + } + + if !acc.done { + return Err(retryable( + "provider stream ended before [DONE] — response is incomplete".to_string(), + )); + } + + Ok(acc.into_output(provider, params)) +} diff --git a/orch8-engine/src/handlers/llm/sse.rs b/orch8-engine/src/handlers/llm/sse.rs new file mode 100644 index 00000000..346bd1e6 --- /dev/null +++ b/orch8-engine/src/handlers/llm/sse.rs @@ -0,0 +1,213 @@ +//! Minimal incremental Server-Sent Events parser for provider streams. +//! +//! Both streaming wire formats consumed by `llm_call` are SSE: +//! - `OpenAI` `/chat/completions` with `stream: true` — `data:` lines carrying +//! JSON chunks, terminated by `data: [DONE]`. +//! - Anthropic `/messages` with `stream: true` — `event:` + `data:` pairs +//! (`message_start`, `content_block_delta`, `message_delta`, …). +//! +//! The parser is byte-buffer based: network chunks may split an event (or +//! even a UTF-8 code point) anywhere, so bytes are buffered until a blank +//! line terminates an event. Splitting on `\n` byte values is UTF-8 safe — +//! `0x0A` never appears inside a multi-byte sequence. + +use std::time::Duration; + +use orch8_types::error::StepError; + +use super::common::retryable; + +/// One parsed SSE event: optional `event:` name plus the joined `data:` payload. +#[derive(Debug, PartialEq, Eq)] +pub(super) struct SseEvent { + /// Value of the `event:` field, when present (Anthropic names its events; + /// `OpenAI` does not). 
+ pub event: Option, + /// All `data:` lines of the event, joined with `\n` per the SSE spec. + pub data: String, +} + +/// Incremental SSE parser. Feed raw body chunks with [`SseParser::push`]; +/// completed events are returned as they terminate. +#[derive(Default)] +pub(super) struct SseParser { + buf: Vec, +} + +impl SseParser { + /// Append a body chunk and drain every event completed by it. + pub fn push(&mut self, chunk: &[u8]) -> Vec { + self.buf.extend_from_slice(chunk); + let mut events = Vec::new(); + while let Some((end, sep_len)) = find_event_boundary(&self.buf) { + let raw: Vec = self.buf.drain(..end + sep_len).collect(); + if let Some(event) = parse_event(&raw[..end]) { + events.push(event); + } + } + events + } +} + +/// Locate the first blank-line event terminator: `\n\n` or `\n\r\n`. +/// Returns `(payload_end, separator_len)`. +fn find_event_boundary(buf: &[u8]) -> Option<(usize, usize)> { + let mut i = 0; + while i + 1 < buf.len() { + if buf[i] == b'\n' { + if buf[i + 1] == b'\n' { + return Some((i, 2)); + } + if buf[i + 1] == b'\r' && i + 2 < buf.len() && buf[i + 2] == b'\n' { + return Some((i, 3)); + } + } + i += 1; + } + None +} + +/// Parse one raw event block into its `event:` / `data:` fields. Returns +/// `None` for blocks carrying neither (comments, `id:` only, keepalives). +fn parse_event(raw: &[u8]) -> Option { + let text = String::from_utf8_lossy(raw); + let mut event = None; + let mut data_lines: Vec<&str> = Vec::new(); + for line in text.lines() { + // With CRLF framing the `\n\r\n` event boundary leaves the final + // line's `\r` in the payload (no trailing `\n` for `lines()` to pair + // it with) — strip it explicitly. 
+ let line = line.strip_suffix('\r').unwrap_or(line); + if let Some(value) = line.strip_prefix("data:") { + data_lines.push(value.strip_prefix(' ').unwrap_or(value)); + } else if let Some(value) = line.strip_prefix("event:") { + event = Some(value.trim().to_string()); + } + // `id:`, `retry:` and `:` comment lines are intentionally ignored. + } + if event.is_none() && data_lines.is_empty() { + return None; + } + Some(SseEvent { + event, + data: data_lines.join("\n"), + }) +} + +/// Default inter-chunk idle timeout: a provider stream that goes silent for +/// this long is treated as stalled and fails retryable. +pub(super) const DEFAULT_STREAM_IDLE_TIMEOUT: Duration = Duration::from_secs(30); + +/// Resolve the inter-chunk idle timeout from `stream_idle_timeout_secs` +/// (default 30s; values of 0 fall back to the default). +pub(super) fn stream_idle_timeout(params: &serde_json::Value) -> Duration { + params + .get("stream_idle_timeout_secs") + .and_then(serde_json::Value::as_u64) + .filter(|secs| *secs > 0) + .map_or(DEFAULT_STREAM_IDLE_TIMEOUT, Duration::from_secs) +} + +/// Read the next body chunk with the inter-chunk idle timeout applied. +/// +/// All mid-stream failures are **retryable**: a stall, a dropped connection, +/// or a transport error after a 200 status are transient conditions — a step +/// retry (or provider failover) re-issues the whole request. 
+pub(super) async fn next_chunk( + resp: &mut reqwest::Response, + idle_timeout: Duration, +) -> Result, StepError> { + match tokio::time::timeout(idle_timeout, resp.chunk()).await { + Err(_) => Err(retryable(format!( + "stream stalled: no data received for {idle_timeout:?}" + ))), + Ok(Ok(chunk)) => Ok(chunk), + Ok(Err(e)) => Err(retryable(format!("stream transport error: {e}"))), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ev(event: Option<&str>, data: &str) -> SseEvent { + SseEvent { + event: event.map(str::to_string), + data: data.to_string(), + } + } + + #[test] + fn parses_single_data_event() { + let mut p = SseParser::default(); + assert_eq!(p.push(b"data: {\"a\":1}\n\n"), vec![ev(None, "{\"a\":1}")]); + } + + #[test] + fn parses_event_split_across_chunks() { + let mut p = SseParser::default(); + assert!(p.push(b"data: hel").is_empty()); + assert!(p.push(b"lo\n").is_empty()); + assert_eq!(p.push(b"\n"), vec![ev(None, "hello")]); + } + + #[test] + fn parses_multiple_events_in_one_chunk() { + let mut p = SseParser::default(); + let events = p.push(b"data: a\n\ndata: b\n\n"); + assert_eq!(events, vec![ev(None, "a"), ev(None, "b")]); + } + + #[test] + fn parses_named_events_with_crlf() { + let mut p = SseParser::default(); + let events = p.push(b"event: message_start\r\ndata: {}\r\n\r\n"); + assert_eq!(events, vec![ev(Some("message_start"), "{}")]); + } + + #[test] + fn joins_multi_line_data() { + let mut p = SseParser::default(); + assert_eq!( + p.push(b"data: line1\ndata: line2\n\n"), + vec![ev(None, "line1\nline2")] + ); + } + + #[test] + fn ignores_comments_and_ids() { + let mut p = SseParser::default(); + assert!(p.push(b": keepalive\n\nid: 7\n\n").is_empty()); + } + + #[test] + fn data_without_space_after_colon() { + let mut p = SseParser::default(); + assert_eq!(p.push(b"data:[DONE]\n\n"), vec![ev(None, "[DONE]")]); + } + + #[test] + fn multibyte_utf8_split_across_chunks_survives() { + let mut p = SseParser::default(); + let payload = 
"data: caf\u{e9}\n\n".as_bytes().to_vec(); + let (a, b) = payload.split_at(8); // splits the two-byte 'é' + assert!(p.push(a).is_empty()); + assert_eq!(p.push(b), vec![ev(None, "caf\u{e9}")]); + } + + #[test] + fn idle_timeout_default_and_override() { + assert_eq!( + stream_idle_timeout(&serde_json::json!({})), + DEFAULT_STREAM_IDLE_TIMEOUT + ); + assert_eq!( + stream_idle_timeout(&serde_json::json!({"stream_idle_timeout_secs": 5})), + Duration::from_secs(5) + ); + assert_eq!( + stream_idle_timeout(&serde_json::json!({"stream_idle_timeout_secs": 0})), + DEFAULT_STREAM_IDLE_TIMEOUT + ); + } +} diff --git a/orch8-engine/src/lib.rs b/orch8-engine/src/lib.rs index ee964447..a5130fb9 100644 --- a/orch8-engine/src/lib.rs +++ b/orch8-engine/src/lib.rs @@ -18,6 +18,7 @@ pub mod scheduler; pub mod scheduling; pub mod sequence_cache; pub mod signals; +pub mod stream_bus; pub mod template; pub mod triggers; pub mod webhooks; diff --git a/orch8-engine/src/stream_bus.rs b/orch8-engine/src/stream_bus.rs new file mode 100644 index 00000000..3fa46077 --- /dev/null +++ b/orch8-engine/src/stream_bus.rs @@ -0,0 +1,174 @@ +//! In-process pub/sub bus for live instance events. +//! +//! Carries incremental events that are NOT part of an instance's durable +//! record — today that is `llm_delta`: token deltas emitted by a streaming +//! `llm_call` step. Producers (step handlers) publish per-instance; consumers +//! (the API's `GET /instances/{id}/stream` SSE endpoint, running in the same +//! process) subscribe per-instance and forward events to connected clients. +//! +//! Design: +//! - One [`tokio::sync::broadcast`] channel per instance, **lazily created on +//! first subscribe**. Publishing to an instance nobody watches is a cheap +//! map lookup and a no-op — no channel is allocated. +//! - Capacity-bounded ([`CHANNEL_CAPACITY`]): a slow subscriber lags (drops +//! oldest events) instead of buffering unboundedly. Deltas are best-effort +//! 
by design — the full accumulated text always lands in the step's durable +//! output. +//! - Channels are dropped when their last subscriber disconnects: a publish +//! that finds no receivers removes the entry, and every subscribe sweeps +//! entries whose receiver count reached zero. +//! +//! The bus lives in `orch8-engine` (not the API crate) because handlers are +//! the producers; engine-only deployments (e.g. mobile) compile it unchanged +//! and simply never subscribe, so it stays inert. + +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock, PoisonError}; + +use serde::Serialize; +use tokio::sync::broadcast; + +use orch8_types::ids::InstanceId; + +/// Per-instance broadcast capacity. A lagging subscriber loses the oldest +/// deltas (best-effort live view); the durable step output is unaffected. +pub const CHANNEL_CAPACITY: usize = 256; + +/// A live (non-durable) event observed while an instance runs. +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StreamEvent { + /// An incremental text delta from a streaming `llm_call` step. + LlmDelta { + /// Block id of the `llm_call` step producing the delta. + block_id: String, + /// The incremental text fragment (already-accumulated text is not + /// repeated; concatenating deltas reproduces the full response text). + delta: String, + }, +} + +/// Registry of per-instance broadcast channels. See the module docs for the +/// lifecycle (lazy creation, capacity bound, drop-on-last-unsubscribe). +#[derive(Default)] +pub struct StreamBus { + channels: Mutex>>, +} + +impl StreamBus { + /// Lock the registry, recovering from a poisoned lock (the registry holds + /// only channel handles, so a panicked holder cannot leave it logically + /// inconsistent). 
+ fn lock( + &self, + ) -> std::sync::MutexGuard<'_, HashMap>> { + self.channels.lock().unwrap_or_else(PoisonError::into_inner) + } + + /// Subscribe to live events for `instance_id`, creating the channel if it + /// does not exist yet. Also sweeps channels whose subscribers are all gone + /// (cheap GC keyed to the rare subscribe path). + pub fn subscribe(&self, instance_id: InstanceId) -> broadcast::Receiver { + let mut map = self.lock(); + map.retain(|_, tx| tx.receiver_count() > 0); + map.entry(instance_id) + .or_insert_with(|| broadcast::channel(CHANNEL_CAPACITY).0) + .subscribe() + } + + /// Publish an event to subscribers of `instance_id`. A no-op when nobody + /// is subscribed; if the last subscriber has gone away, the channel is + /// dropped here. + pub fn publish(&self, instance_id: InstanceId, event: StreamEvent) { + let mut map = self.lock(); + if let Some(tx) = map.get(&instance_id) { + if tx.send(event).is_err() { + // No live receivers — drop the channel so the map can't grow + // with stale entries between subscribes. + map.remove(&instance_id); + } + } + } + + /// `true` when at least one subscriber is listening for `instance_id`. + #[must_use] + pub fn has_subscribers(&self, instance_id: InstanceId) -> bool { + self.lock() + .get(&instance_id) + .is_some_and(|tx| tx.receiver_count() > 0) + } +} + +/// Process-global stream bus. The API layer (same process in `orch8-server`) +/// subscribes here; the `llm_call` handler publishes here. 
+pub fn stream_bus() -> &'static StreamBus {
+    static BUS: OnceLock<StreamBus> = OnceLock::new();
+    BUS.get_or_init(StreamBus::default)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn delta(block: &str, text: &str) -> StreamEvent {
+        StreamEvent::LlmDelta {
+            block_id: block.to_string(),
+            delta: text.to_string(),
+        }
+    }
+
+    #[tokio::test]
+    async fn publish_reaches_subscriber() {
+        let bus = StreamBus::default();
+        let id = InstanceId::new();
+        let mut rx = bus.subscribe(id);
+        bus.publish(id, delta("b1", "hel"));
+        bus.publish(id, delta("b1", "lo"));
+        assert_eq!(rx.recv().await.unwrap(), delta("b1", "hel"));
+        assert_eq!(rx.recv().await.unwrap(), delta("b1", "lo"));
+    }
+
+    #[tokio::test]
+    async fn publish_without_subscriber_is_noop() {
+        let bus = StreamBus::default();
+        let id = InstanceId::new();
+        bus.publish(id, delta("b", "x"));
+        assert!(bus.lock().get(&id).is_none(), "no channel created by publish");
+    }
+
+    #[tokio::test]
+    async fn channel_dropped_after_last_subscriber_goes_away() {
+        let bus = StreamBus::default();
+        let id = InstanceId::new();
+        let rx = bus.subscribe(id);
+        assert!(bus.has_subscribers(id));
+        drop(rx);
+        assert!(!bus.has_subscribers(id));
+        // The next publish observes zero receivers and removes the entry.
+        bus.publish(id, delta("b", "x"));
+        assert!(bus.lock().get(&id).is_none(), "stale channel removed");
+    }
+
+    #[tokio::test]
+    async fn events_are_scoped_per_instance() {
+        let bus = StreamBus::default();
+        let (a, b) = (InstanceId::new(), InstanceId::new());
+        let mut rx_a = bus.subscribe(a);
+        let mut rx_b = bus.subscribe(b);
+        bus.publish(a, delta("blk", "only-a"));
+        assert_eq!(rx_a.recv().await.unwrap(), delta("blk", "only-a"));
+        assert!(matches!(
+            rx_b.try_recv(),
+            Err(broadcast::error::TryRecvError::Empty)
+        ));
+    }
+
+    #[test]
+    fn llm_delta_serializes_with_type_tag() {
+        let json = serde_json::to_value(delta("step-1", "hi")).unwrap();
+        assert_eq!(
+            json,
+            serde_json::json!({"type": "llm_delta", "block_id": "step-1", "delta": "hi"})
+        );
+    }
+}