Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions orch8-api/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
//! `GET /instances/{id}/stream` returns a Server-Sent Events stream that emits:
//! - `state` events when instance state changes
//! - `output` events when new block outputs appear
//! - `llm_delta` events with incremental text from a streaming `llm_call`
//! step (`stream: true`), forwarded live from the in-process
//! [`orch8_engine::stream_bus`] — not polled from storage
//! - `done` event when the instance reaches a terminal state
//!
//! The stream polls storage at a configurable interval (default 500ms) and sends
//! keepalive comments every 15s. Closes when the instance reaches a terminal state.
//! State and output changes are detected by polling storage at a configurable
//! interval (default 500ms); keepalive comments are sent every 15s. The stream
//! closes when the instance reaches a terminal state.

use std::convert::Infallible;
use std::time::Duration;

use axum::extract::{Path, Query, State};
use axum::response::sse::{Event, KeepAlive, Sse};
use serde::Deserialize;
use tokio::sync::broadcast;
use uuid::Uuid;

use orch8_engine::stream_bus::{stream_bus, StreamEvent};
use orch8_types::ids::InstanceId;
use orch8_types::instance::InstanceState;

Expand All @@ -33,6 +39,19 @@ const fn default_poll_ms() -> u64 {
500
}

/// Resolve to the next live event from the stream-bus subscription.
///
/// When `rx` is `Some`, this simply awaits the receiver and surfaces its
/// result, including `RecvError::Lagged` / `RecvError::Closed`, untouched.
/// When `rx` is `None` (the subscription was dropped after the channel
/// closed) the future never completes; the caller guards its select arm
/// with `delta_rx.is_some()`, so that case is never actually polled.
async fn next_delta(
    rx: Option<&mut broadcast::Receiver<StreamEvent>>,
) -> Result<StreamEvent, broadcast::error::RecvError> {
    if let Some(receiver) = rx {
        receiver.recv().await
    } else {
        // No subscription left: park forever instead of fabricating an error.
        std::future::pending().await
    }
}

const fn is_terminal(state: InstanceState) -> bool {
matches!(
state,
Expand All @@ -47,7 +66,7 @@ const fn is_terminal(state: InstanceState) -> bool {
("poll_ms" = Option<u64>, Query, description = "Poll interval in ms (min 100ms)"),
),
responses(
(status = 200, description = "Server-Sent Events stream of instance state/tree/output changes", content_type = "text/event-stream"),
(status = 200, description = "Server-Sent Events stream of instance state/output changes plus live llm_delta events from streaming llm_call steps", content_type = "text/event-stream"),
(status = 404, description = "Instance not found"),
(status = 503, description = "Too many concurrent streams"),
)
Expand Down Expand Up @@ -103,6 +122,12 @@ pub(crate) async fn stream_instance(
let shutdown = state.shutdown.clone();
let storage = state.storage.clone();

// Subscribe to the in-process stream bus BEFORE the response starts, so
// llm_delta events published after the client connects are never missed.
// The engine runs in the same process (orch8-server); in API-only
// deployments the channel simply never receives anything.
let mut delta_rx = Some(stream_bus().subscribe(instance_id));

tokio::spawn(async move {
// Hold the permit for the lifetime of the stream task; dropped
// automatically on return/panic, freeing a slot.
Expand All @@ -123,6 +148,36 @@ pub(crate) async fn stream_instance(
() = shutdown.cancelled() => break,
() = tx.closed() => break,
_ = ticker.tick() => {}
event = next_delta(delta_rx.as_mut()), if delta_rx.is_some() => {
match event {
Ok(ev) => {
// `StreamEvent` serialization is infallible (plain
// strings + tag); fall back to `{}` defensively.
let payload = serde_json::to_string(&ev)
.unwrap_or_else(|_| "{}".to_string());
let sse = Event::default().event("llm_delta").data(payload);
if tx.send(Ok(sse)).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
// Best-effort live view: a slow client just loses
// deltas; the durable output event carries the
// full text anyway.
tracing::debug!(
skipped,
instance_id = %instance_id.into_uuid(),
"stream: client lagged behind llm_delta broadcast"
);
}
Err(broadcast::error::RecvError::Closed) => {
delta_rx = None;
}
}
// Deltas don't advance the storage poll; wait for the
// next tick before re-querying.
continue;
}
}

// If the last iterations failed, wait an extra backoff period
Expand Down
175 changes: 175 additions & 0 deletions orch8-api/tests/streaming.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
//! E2E tests for `GET /instances/{id}/stream` — specifically the forwarding
//! of live `llm_delta` events from the in-process engine stream bus.
//!
//! The test harness does not run the engine tick loop, so instead of
//! executing a real streaming `llm_call` we publish to the bus directly
//! (exactly what the handler's `DeltaSink` does) and assert a connected SSE
//! client receives the events. The handler-side publication path is covered
//! by `orch8-engine`'s mock-SSE provider tests.

use std::time::Duration;

use orch8_api::test_harness::spawn_test_server;
use orch8_engine::stream_bus::{stream_bus, StreamEvent};
use orch8_types::ids::InstanceId;
use reqwest::StatusCode;
use serde_json::json;
use uuid::Uuid;

fn mk_sequence_body(id: Uuid) -> serde_json::Value {
json!({
"id": id,
"tenant_id": "t1",
"namespace": "ns1",
"name": "stream-seq",
"version": 1,
"deprecated": false,
"blocks": [
{
"type": "step",
"id": "s1",
"handler": "noop",
"params": {},
"cancellable": true
}
],
"interceptors": null,
"created_at": chrono::Utc::now().to_rfc3339()
})
}

/// Register a fresh sequence (body from `mk_sequence_body`) and start an
/// instance of it, both as tenant `t1`; returns the new instance's id.
///
/// Panics if either request cannot be sent, does not answer `201 Created`,
/// or the instance response has no string `id` that parses as a UUID.
async fn create_instance(client: &reqwest::Client, base_url: &str) -> Uuid {
    let seq_id = Uuid::now_v7();

    // Step 1: register the sequence definition.
    let sequences_url = format!("{base_url}/sequences");
    let sequence_resp = client
        .post(sequences_url)
        .header("X-Tenant-Id", "t1")
        .json(&mk_sequence_body(seq_id))
        .send()
        .await
        .unwrap();
    assert_eq!(sequence_resp.status(), StatusCode::CREATED);

    // Step 2: create an instance of that sequence with an empty context.
    let instance_request = json!({
        "sequence_id": seq_id,
        "tenant_id": "t1",
        "namespace": "ns1",
        "context": { "data": {}, "config": {}, "audit": [] }
    });
    let instances_url = format!("{base_url}/instances");
    let instance_resp = client
        .post(instances_url)
        .header("X-Tenant-Id", "t1")
        .json(&instance_request)
        .send()
        .await
        .unwrap();
    assert_eq!(instance_resp.status(), StatusCode::CREATED);

    let created: serde_json::Value = instance_resp.json().await.unwrap();
    let raw_id = created["id"].as_str().unwrap();
    raw_id.parse().unwrap()
}

/// Keep appending SSE chunks from `resp` to `buf` until `buf` contains
/// `needle`. Returns immediately if `buf` already contains it.
///
/// Chunks are decoded with `String::from_utf8_lossy`. Panics if a read fails,
/// if the stream ends before `needle` shows up, or if the whole wait exceeds
/// the 10-second safety cap (the panic message includes what was received).
async fn read_until(resp: &mut reqwest::Response, buf: &mut String, needle: &str) {
    let safety_cap = Duration::from_secs(10);

    let fill = async {
        loop {
            if buf.contains(needle) {
                break;
            }
            let bytes = resp
                .chunk()
                .await
                .expect("SSE read failed")
                .expect("SSE stream ended before expected event");
            buf.push_str(&String::from_utf8_lossy(&bytes));
        }
    };

    let outcome = tokio::time::timeout(safety_cap, fill).await;
    assert!(
        outcome.is_ok(),
        "timed out waiting for {needle:?} in SSE stream; got: {buf}"
    );
}

/// A client connected to `/instances/{id}/stream` receives every `LlmDelta`
/// published on the stream bus for that instance as an `llm_delta` SSE event
/// carrying the serialized event as its data.
#[tokio::test]
async fn stream_forwards_llm_delta_events_from_bus() {
    let srv = spawn_test_server().await;
    let client = reqwest::Client::new();
    let inst_id = create_instance(&client, &srv.base_url).await;

    let stream_url = format!("{}/instances/{inst_id}/stream", srv.base_url);
    let mut resp = client
        .get(stream_url)
        .header("X-Tenant-Id", "t1")
        .query(&[("poll_ms", "100")])
        .send()
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    // The handler subscribes to the bus before the response starts, but we
    // still wait for the first poll-driven `state` event so the stream task
    // is known to be up and running before anything is published.
    let mut buf = String::new();
    read_until(&mut resp, &mut buf, "event: state").await;

    // Publish the same events a streaming llm_call's DeltaSink would.
    let instance_id = InstanceId::from_uuid(inst_id);
    for piece in ["Hel", "lo"] {
        let event = StreamEvent::LlmDelta {
            block_id: "s1".to_string(),
            delta: piece.to_string(),
        };
        stream_bus().publish(instance_id, event);
    }

    // First the SSE event name, then each delta payload in publish order.
    let expected_in_order = [
        "event: llm_delta",
        r#"{"type":"llm_delta","block_id":"s1","delta":"Hel"}"#,
        r#"{"type":"llm_delta","block_id":"s1","delta":"lo"}"#,
    ];
    for needle in expected_in_order {
        read_until(&mut resp, &mut buf, needle).await;
    }
}

/// Deltas published for a different instance must never show up on this
/// instance's SSE stream.
#[tokio::test]
async fn stream_ignores_deltas_for_other_instances() {
    let srv = spawn_test_server().await;
    let client = reqwest::Client::new();
    let inst_id = create_instance(&client, &srv.base_url).await;

    let stream_url = format!("{}/instances/{inst_id}/stream", srv.base_url);
    let mut resp = client
        .get(stream_url)
        .header("X-Tenant-Id", "t1")
        .query(&[("poll_ms", "100")])
        .send()
        .await
        .unwrap();
    assert_eq!(resp.status(), StatusCode::OK);

    let mut buf = String::new();
    read_until(&mut resp, &mut buf, "event: state").await;

    // Publish a delta for an unrelated instance first, then a marker for the
    // watched one. Once the marker has arrived, anything published before it
    // that was going to be delivered would already be in `buf`.
    let foreign_event = StreamEvent::LlmDelta {
        block_id: "other".to_string(),
        delta: "foreign".to_string(),
    };
    stream_bus().publish(InstanceId::new(), foreign_event);

    let marker_event = StreamEvent::LlmDelta {
        block_id: "s1".to_string(),
        delta: "marker".to_string(),
    };
    stream_bus().publish(InstanceId::from_uuid(inst_id), marker_event);

    read_until(&mut resp, &mut buf, r#""delta":"marker""#).await;

    let leaked = buf.contains("foreign");
    assert!(
        !leaked,
        "delta for another instance leaked into this stream: {buf}"
    );
}
Loading
Loading