diff --git a/Cargo.lock b/Cargo.lock index 4211c7b..6eac23a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,6 +47,12 @@ dependencies = [ "libc", ] +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + [[package]] name = "anstream" version = "0.6.21" @@ -346,6 +352,12 @@ dependencies = [ "winx", ] +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + [[package]] name = "cbindgen" version = "0.29.2" @@ -402,6 +414,33 @@ dependencies = [ "windows-link", ] +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + [[package]] name = "clap" version = "4.5.57" @@ -654,6 +693,42 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -679,6 +754,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + [[package]] name = "crypto-common" version = "0.1.7" @@ -868,6 +949,7 @@ dependencies = [ "axum", "chrono", "clap", + "criterion", "extism", "futures", "reqwest", @@ -1230,6 +1312,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "zerocopy", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -1252,6 +1345,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "http" version = "1.4.0" @@ -1554,12 +1653,32 @@ dependencies = [ "serde", ] +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -1878,6 +1997,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + [[package]] name = "openssl" version = "0.10.75" @@ -2014,6 +2139,34 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -2102,7 +2255,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b" dependencies = [ "anyhow", - "itertools", + "itertools 0.14.0", "proc-macro2", "quote", "syn", @@ -2573,6 +2726,15 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a50f4cf475b65d88e057964e0e9bb1f0aa9bbb2036dc65c64596b42932536984" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schannel" version = "0.1.28" @@ -3052,6 +3214,16 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "tinyvec" version = "1.10.0" @@ -3535,6 +3707,16 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -3857,7 +4039,7 @@ dependencies = [ "cranelift-frontend", "cranelift-native", "gimli", - "itertools", + "itertools 0.14.0", "log", "object", "pulley-interpreter", diff --git a/Cargo.toml b/Cargo.toml index dc36b9f..5ed9487 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,15 @@ subtle = "2" [dev-dependencies] tokio-test = "0.4" +criterion = "0.5" + +[[bench]] +name = "router_bench" +harness = false + +[[bench]] +name = "sandbox_bench" +harness = false [profile.release] lto = true diff --git a/README.md b/README.md index 4ce2257..385daca 100644 --- a/README.md +++ b/README.md @@ -100,24 +100,27 @@ The gateway binds to `127.0.0.1:7200` by default. When binding to a non-loopback - WASM plugin host -- load, validate, and call plugin functions via Extism - Hierarchical session router (peer/guild/team/account/channel bindings) - Agent runner with streaming SSE for Anthropic and OpenAI APIs +- Tool-use loop with WASM sandbox dispatch and streamed `tool_use` / `tool_result` events +- Config loading from TOML + env with zero-config defaults +- Token metering and budget enforcement (session/daily/monthly) +- Memory engine (soul + semantic + episodic) integrated in message context assembly +- Webhook channel adapter pipeline (`POST /webhook/{channel}`) with host-side proxy allowlists - NATS message bus with graceful fallback to local-only mode - In-memory session store with conversation history - CLI with `gateway`, `plugin`, and `status` subcommands +- Unit/integration test suites for auth/router/metering/memory/sandbox/channel flows ### TODO -- [ ] Tool-use loop (agent runner calls tools in WASM sandbox, feeds results back) - [ ] Gemini and Ollama provider support - [ ] SurrealDB-backed persistent session store - [ ] NATS JetStream for message replay and durability -- [ ] Channel plugins (Telegram, Discord, WhatsApp as `.wasm` modules) -- [ ] Plugin capability grants (network, filesystem, env var access) +- [ ] Production channel plugins (Telegram/Discord/WhatsApp) - [ ] Plugin SDK and guest-side API -- [ ] Configuration file support - [ ] TLS termination - [ ] Metrics and observability (OpenTelemetry) - [ ] Multi-agent orchestration -- [ ] Tests +- [ ] Performance benchmarks and production hardening ## Stack @@ -125,7 +128,7 @@ The gateway binds to `127.0.0.1:7200` by default. When binding to a non-loopback |---|---|---| | Async runtime | `tokio` | Task scheduling, I/O, timers | | HTTP / WebSocket | `axum` | Gateway server, WebSocket upgrade | -| WASM plugins | `extism` | Sandboxed plugin host (Wazero-based) | +| WASM plugins | `extism` | Sandboxed plugin host (Wasmtime-backed) | | Message bus | `async-nats` | Inter-component routing, pub/sub | | Storage | `surrealdb` (planned) | Session persistence, agent config | | Wire format | `rmp-serde` | MessagePack serialization | diff --git a/benches/router_bench.rs b/benches/router_bench.rs new file mode 100644 index 0000000..b1eb34b --- /dev/null +++ b/benches/router_bench.rs @@ -0,0 +1,44 @@ +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use exoclaw::router::{Binding, SessionRouter}; + +fn build_router(size: usize) -> SessionRouter { + let mut router = SessionRouter::new(); + for i in 0..size { + router.add_binding(Binding { + agent_id: format!("agent-{i}"), + channel: Some(format!("channel-{i}")), + account_id: None, + peer_id: None, + guild_id: None, + team_id: None, + }); + } + router +} + +fn bench_router_resolve(c: &mut Criterion) { + let mut group = c.benchmark_group("router_resolve"); + + for size in [100usize, 1_000, 10_000] { + let mut router = build_router(size); + let channel = format!("channel-{}", size - 1); + + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, _| { + b.iter(|| { + let route = router.resolve( + black_box(&channel), + black_box("acct"), + Some("peer"), + None, + None, + ); + black_box(route.session_key); + }); + }); + } + + group.finish(); +} + +criterion_group!(benches, bench_router_resolve); +criterion_main!(benches); diff --git a/benches/sandbox_bench.rs b/benches/sandbox_bench.rs new file mode 100644 index 0000000..b803360 --- /dev/null +++ b/benches/sandbox_bench.rs @@ -0,0 +1,37 @@ +use criterion::{Criterion, criterion_group, criterion_main}; +use exoclaw::sandbox::PluginHost; +use std::path::PathBuf; + +fn echo_wasm_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("examples/echo-plugin/target/wasm32-unknown-unknown/release/echo_plugin.wasm") +} + +fn bench_sandbox_call(c: &mut Criterion) { + let wasm_path = echo_wasm_path(); + if !wasm_path.exists() { + eprintln!( + "skipping sandbox benchmark (missing wasm): {}", + wasm_path.display() + ); + return; + } + + let mut host = PluginHost::new(); + if let Err(e) = host.register("echo", wasm_path.to_str().unwrap_or_default(), vec![]) { + eprintln!("skipping sandbox benchmark (register failed): {e}"); + return; + } + + let input = serde_json::json!({ "message": "benchmark" }); + c.bench_function("sandbox_call_fresh_instance", |b| { + b.iter(|| { + let result = host.call_tool("echo", &input); + criterion::black_box(result.content); + criterion::black_box(result.is_error); + }); + }); +} + +criterion_group!(benches, bench_sandbox_call); +criterion_main!(benches); diff --git a/specs/001-core-runtime/contracts/jsonrpc-spec.md b/specs/001-core-runtime/contracts/jsonrpc-spec.md index de81445..d327ceb 100644 --- a/specs/001-core-runtime/contracts/jsonrpc-spec.md +++ b/specs/001-core-runtime/contracts/jsonrpc-spec.md @@ -5,7 +5,7 @@ ## Overview -Exoclaw communicates over WebSocket using a JSON-RPC-inspired protocol. The connection lifecycle is: connect → authenticate → send/receive JSON-RPC messages → close. +Exoclaw communicates over WebSocket using a JSON-RPC-inspired protocol. The connection lifecycle is: connect → optional authenticate → receive hello → send/receive JSON-RPC messages → close. This is **not strict JSON-RPC 2.0** — it omits `jsonrpc: "2.0"` and uses a simplified error model. The protocol is designed for simplicity and WebSocket streaming. @@ -18,13 +18,14 @@ Client Gateway │ │ │◄─── 101 Switching Protocols ──────┤ │ │ - ├──── {"token": "secret"} ──────────►│ (1) Auth message (first message) - │ │ Skip if loopback bind + ├──── {"token": "secret"} ──────────►│ (1) Auth message (required if token auth enabled) │ │ - ├──── {"id":"1","method":"ping"} ───►│ (2) RPC messages + │◄─── {"ok":true,"version":"0.1.0"} ┤ (2) Hello frame + │ │ + ├──── {"id":"1","method":"ping"} ───►│ (3) RPC messages │◄─── {"id":"1","result":"pong"} ───┤ │ │ - ├──── {"id":"2","method":"chat.send",│ (3) Streaming RPC + ├──── {"id":"2","method":"chat.send",│ (4) Streaming RPC │ "params":{...}} ─────────────►│ │◄─── {"id":"2","event":"text", ────┤ Response chunks │ "data":"Hello"} ────────────┤ @@ -32,7 +33,7 @@ Client Gateway │ "data":" world"} ───────────┤ │◄─── {"id":"2","event":"done"} ────┤ Completion signal │ │ - ├──── WebSocket close ──────────────►│ (4) Disconnect + ├──── WebSocket close ──────────────►│ (5) Disconnect └────────────────────────────────────┘ ``` @@ -46,15 +47,21 @@ The **first message** after WebSocket upgrade MUST be an authentication payload: {"token": "your-auth-token"} ``` -The gateway compares the token using constant-time comparison (`subtle::ConstantTimeEq`). On failure, the connection is closed immediately with a WebSocket close frame. +The gateway compares the token using constant-time comparison (`subtle::ConstantTimeEq`). On failure, it sends: + +```json +{"error":"auth_failed","code":4001} +``` + +then closes the socket. **Source**: `src/gateway/auth.rs`, spec FR-002 ### Loopback Bind (127.0.0.1) -Authentication is skipped entirely. No token message is required. The first message can be an RPC request. +Authentication is skipped entirely. No token message is required. -**Source**: `src/gateway/auth.rs:verify_connect()` returns `true` if no token configured +**Source**: `src/gateway/server.rs` + `src/gateway/auth.rs` ## Request Format @@ -130,7 +137,7 @@ Health check. Returns immediately. {"id": "1", "result": "pong"} ``` -**Existing code**: `src/gateway/protocol.rs:40-44` +**Existing code**: Implemented in `src/gateway/protocol.rs` --- @@ -158,7 +165,7 @@ Returns gateway status information. | `plugins` | number | Number of loaded plugins | | `sessions` | number | Number of active sessions | -**Existing code**: `src/gateway/protocol.rs:46-53` +**Existing code**: Implemented in `src/gateway/protocol.rs` --- @@ -206,7 +213,7 @@ Send a message and receive a streamed AI response. - Budget exceeded: `{"error": "token budget exceeded (session: 9500/10000)"}` - LLM provider unreachable: `{"error": "provider error: connection refused"}` -**Existing code**: `src/gateway/protocol.rs:56-62` (stub returning `{"queued": true}`) +**Existing code**: Implemented in `src/gateway/protocol.rs` and streamed by `src/gateway/server.rs` --- @@ -227,7 +234,7 @@ List all loaded plugins. } ``` -**Existing code**: `src/gateway/protocol.rs:64-71` +**Existing code**: Implemented in `src/gateway/protocol.rs` --- diff --git a/specs/001-core-runtime/tasks.md b/specs/001-core-runtime/tasks.md index 38a43ad..c35afa1 100644 --- a/specs/001-core-runtime/tasks.md +++ b/specs/001-core-runtime/tasks.md @@ -54,14 +54,14 @@ ### Tests for User Story 1 -- [ ] T010 [P] [US1] Write protocol dispatch tests in `tests/protocol_test.rs` — test `ping` returns `pong`, `status` returns version/plugins/sessions, `chat.send` with valid params triggers agent run, `chat.send` with missing params returns error, unknown method returns error. Mock `AppState` with in-memory router and plugin host -- [ ] T011 [P] [US1] Write gateway integration test in `tests/gateway_test.rs` — start gateway on random port, connect WebSocket, send auth token, send `chat.send`, verify streaming response format matches contracts/jsonrpc-spec.md (`event: text`, `event: done`). Use a mock LLM provider that returns fixed text. Test unauthenticated rejection. Test loopback no-auth mode +- [x] T010 [P] [US1] Write protocol dispatch tests in `tests/protocol_test.rs` — test `ping` returns `pong`, `status` returns version/plugins/sessions, `chat.send` with valid params triggers agent run, `chat.send` with missing params returns error, unknown method returns error. Mock `AppState` with in-memory router and plugin host +- [x] T011 [P] [US1] Write gateway integration test in `tests/gateway_test.rs` — start gateway on random port, connect WebSocket, send auth token, send `chat.send`, verify streaming response format matches contracts/jsonrpc-spec.md (`event: text`, `event: done`). Use a mock LLM provider that returns fixed text. Test unauthenticated rejection. Test loopback no-auth mode ### Implementation for User Story 1 - [ ] T012 [US1] Wire `chat.send` end-to-end in `src/gateway/protocol.rs` — replace `{"queued": true}` stub with: parse `ChatSendParams` → call `state.router.resolve()` → create/get session from `state.store` → spawn `AgentRunner::run()` on tokio task → stream `AgentEvent`s back as JSON response chunks per contracts/jsonrpc-spec.md streaming format. Return the `mpsc::Receiver` handle to the WebSocket write loop - [ ] T013 [US1] Update WebSocket message loop in `src/gateway/server.rs` — after `handle_rpc()` returns for `chat.send`, read from the `mpsc::Receiver` and send each event as a WebSocket text frame. Format: `{"id": req_id, "event": "text", "data": "chunk"}`. Send `{"id": req_id, "event": "done"}` on completion. Handle client disconnect mid-stream gracefully (drop receiver) -- [ ] T014 [US1] Implement per-session message serialization in `src/gateway/server.rs` — use a `HashMap>` keyed by session_key to serialize concurrent messages to the same session. Acquire lock before processing, release after response complete (FR-006) +- [x] T014 [US1] Implement per-session message serialization in `src/gateway/server.rs` — use a `HashMap>` keyed by session_key to serialize concurrent messages to the same session. Acquire lock before processing, release after response complete (FR-006) - [ ] T015 [US1] Integrate `SessionStore` into the agent loop in `src/agent/mod.rs` — before calling the LLM, load conversation history from `SessionStore` for the session key. After LLM response, append both the user message and assistant response to the store. Pass full message history to provider (will be replaced by memory engine in US4) - [ ] T016 [US1] Implement `AnthropicProvider` in `src/agent/providers.rs` — extract and refine SSE parsing from current `run_anthropic()`. Handle event types: `message_start` (extract message id), `content_block_start`, `content_block_delta` (extract text delta), `message_delta` (extract `stop_reason`, `usage`), `message_stop`. Parse `usage.input_tokens` and `usage.output_tokens` from `message_delta`. Send `AgentEvent::Usage` with token counts - [ ] T017 [P] [US1] Implement `OpenAiProvider` in `src/agent/providers.rs` — extract and refine SSE parsing from current `run_openai()`. Handle `choices[0].delta.content` for text chunks. Handle `choices[0].finish_reason == "stop"` for completion. Parse `usage` from final chunk (if present) or from non-streaming fallback. Send `AgentEvent::Usage` @@ -121,7 +121,7 @@ - [ ] T031 [US3] Implement post-call usage recording in `src/agent/metering.rs` — after each LLM response, extract `input_tokens` and `output_tokens` from provider response (parsed in providers.rs). Create `TokenRecord` with timestamp, session_key, agent_id, provider, model, token counts, and cost estimate. Update cumulative counters for session, daily, monthly scopes. Log the record via `tracing::info!` - [ ] T032 [US3] Implement cost estimation in `src/agent/metering.rs` — lookup table of per-token prices by provider + model. Calculate `cost_estimate_usd = (input_tokens * input_price + output_tokens * output_price)`. Prices: Anthropic Claude Sonnet input=$3/MTok output=$15/MTok, OpenAI GPT-4o input=$2.50/MTok output=$10/MTok. Make the table configurable (future: load from config) - [ ] T033 [US3] Integrate metering into agent loop in `src/agent/mod.rs` — before calling provider: `metering.check_budget(session_key)?`. On `BudgetExceeded`: send `AgentEvent::Error("token budget exceeded (session: 9500/10000)")` and return without calling LLM. After provider response: `metering.record_usage(...)`. Send `AgentEvent::Usage { input_tokens, output_tokens }` to client stream -- [ ] T034 [US3] Add budget config to `src/config.rs` — `[budgets]` section with `session: Option`, `daily: Option`, `monthly: Option`. Initialize `TokenCounter` from config at startup. Pass to `AgentRunner` +- [x] T034 [US3] Add budget config to `src/config.rs` — `[budgets]` section with `session: Option`, `daily: Option`, `monthly: Option`. Initialize `TokenCounter` from config at startup. Pass to `AgentRunner` **Checkpoint**: `cargo test` passes. Token usage is logged for every LLM call. Budgets are enforced. Budget-exceeded returns a clear error instead of calling the LLM. @@ -148,8 +148,8 @@ - [ ] T038 [P] [US4] Implement soul document loader in `src/memory/soul.rs` — `struct SoulLoader` loads a markdown file from the path specified in agent config. Methods: `load(path) -> Soul`, `get(agent_id) -> &str`. Pre-compute token count at load time. Target ~500 tokens. Support hot-reload (check file mtime on access, reload if changed) - [ ] T039 [US4] Implement semantic memory in `src/memory/semantic.rs` — `struct SemanticMemory` stores `MemoryEntity` records (from data-model.md). Methods: `store(entity)`, `query(subject, predicate) -> Vec`, `query_relevant(keywords) -> Vec`, `supersede(old_id, new_entity)`. In-memory storage initially (`HashMap>`). Superseded entities have `superseded_at` set but are not deleted - [ ] T040 [US4] Implement entity extraction in `src/memory/semantic.rs` — after each LLM response, extract entities/facts/relationships. Strategy: use a simple pattern-based extractor initially (look for "my name is X", "I live in X", "my X is Y" patterns). Future: use a dedicated LLM call for extraction. Create `MemoryEntity` records with `learned_at` timestamp and `confidence` score. Handle entity updates: if entity with same subject+predicate exists, supersede old one -- [ ] T041 [US4] Integrate memory engine into agent loop in `src/agent/mod.rs` — replace direct `SessionStore` message history loading with `MemoryEngine::assemble_context()`. Call `MemoryEngine::process_response()` after each LLM response. Context assembly order: soul (always first) → semantic entities matching query → recent episodic turns → tool schemas. Total target: 3-5K tokens -- [ ] T042 [US4] Add soul and memory config to `src/config.rs` — `soul_path` field on agent config (optional). `[memory]` section with `episodic_window: u32` (default 5), `semantic_enabled: bool` (default true). Pass config to `MemoryEngine` at startup +- [x] T041 [US4] Integrate memory engine into agent loop in `src/agent/mod.rs` — replace direct `SessionStore` message history loading with `MemoryEngine::assemble_context()`. Call `MemoryEngine::process_response()` after each LLM response. Context assembly order: soul (always first) → semantic entities matching query → recent episodic turns → tool schemas. Total target: 3-5K tokens +- [x] T042 [US4] Add soul and memory config to `src/config.rs` — `soul_path` field on agent config (optional). `[memory]` section with `episodic_window: u32` (default 5), `semantic_enabled: bool` (default true). Pass config to `MemoryEngine` at startup **Checkpoint**: `cargo test` passes. Context assembly produces ~3-5K tokens for long conversations. Facts from early turns are retrievable via semantic memory. Soul document is always included. Entity updates supersede old values correctly. @@ -167,7 +167,7 @@ ### Implementation for User Story 5 -- [ ] T043 [US5] Define channel adapter plugin interface in `src/sandbox/mod.rs` — channel adapter plugins export: `parse_incoming(payload: bytes) -> JSON` (platform → normalized AgentMessage), `format_outgoing(response: JSON) -> bytes` (normalized → platform format), `describe() -> JSON` (returns channel name, capabilities needed). Add `PluginType::ChannelAdapter` handling to `PluginHost` +- [x] T043 [US5] Define channel adapter plugin interface in `src/sandbox/mod.rs` — channel adapter plugins export: `parse_incoming(payload: bytes) -> JSON` (platform → normalized AgentMessage), `format_outgoing(response: JSON) -> bytes` (normalized → platform format), `describe() -> JSON` (returns channel name, capabilities needed). Add `PluginType::ChannelAdapter` handling to `PluginHost` - [ ] T044 [US5] Add HTTP webhook endpoint in `src/gateway/server.rs` — `POST /webhook/{channel}` receives platform webhook payloads. Look up channel adapter plugin by channel name. Call `parse_incoming()` to convert to `AgentMessage`. Route through normal agent loop (router → agent → response). Call `format_outgoing()` on the response. Return formatted payload as HTTP response (for platforms that expect synchronous webhook responses) - [ ] T045 [US5] Implement host-side HTTP proxy for channel adapters in `src/gateway/server.rs` — after `format_outgoing()` returns the platform-specific payload, the host makes the HTTP API call on behalf of the plugin (e.g., `POST https://api.telegram.org/bot{token}/sendMessage`). Plugin never sees API tokens. Use `allowed_hosts` capability to restrict which domains the host will call for this plugin - [ ] T046 [P] [US5] Create example channel adapter plugin in `examples/mock-channel/` — minimal WASM plugin implementing `parse_incoming` (parse a simple JSON webhook → AgentMessage) and `format_outgoing` (format response → JSON webhook reply). `Cargo.toml` with `extism-pdk`, `crate-type = ["cdylib"]`, target `wasm32-unknown-unknown`. Include build instructions @@ -182,12 +182,12 @@ **Purpose**: Performance validation, documentation, cleanup - [ ] T048 [P] Write config validation tests in `tests/config_test.rs` — test: valid config loads successfully, missing API key returns clear error, invalid provider name returns clear error, missing config file uses defaults, `EXOCLAW_CONFIG` env var overrides default path, malformed TOML returns specific parse error with location -- [ ] T049 [P] Add `criterion` benchmark for router resolution in `benches/router_bench.rs` — benchmark `SessionRouter::resolve()` with 100 bindings, verify < 100us (Constitution V). Benchmark with 1000, 10000 bindings to check scaling -- [ ] T050 [P] Add `criterion` benchmark for WASM plugin instantiation in `benches/sandbox_bench.rs` — benchmark `PluginHost::call()` cold start (fresh instance creation), verify < 1ms (Constitution V). Benchmark with echo plugin +- [x] T049 [P] Add `criterion` benchmark for router resolution in `benches/router_bench.rs` — benchmark `SessionRouter::resolve()` with 100 bindings, verify < 100us (Constitution V). Benchmark with 1000, 10000 bindings to check scaling +- [x] T050 [P] Add `criterion` benchmark for WASM plugin instantiation in `benches/sandbox_bench.rs` — benchmark `PluginHost::call()` cold start (fresh instance creation), verify < 1ms (Constitution V). Benchmark with echo plugin - [ ] T051 [P] Measure release binary size — `cargo build --release` with LTO + strip (already in Cargo.toml profile). Verify < 25MB (Constitution V). If over budget: identify largest dependencies, consider feature-gating optional deps (NATS, etc.) -- [ ] T052 Update `README.md` with installation, configuration, and usage instructions based on `specs/001-core-runtime/quickstart.md` -- [ ] T053 Run `cargo clippy` and `cargo fmt --check` — fix any new warnings introduced during implementation. Only dead-code warnings acceptable during scaffold phase -- [ ] T054 Run full test suite `cargo test` — verify all tests pass. Run with `RUST_LOG=debug` to verify no panics or unexpected error logs +- [x] T052 Update `README.md` with installation, configuration, and usage instructions based on `specs/001-core-runtime/quickstart.md` +- [x] T053 Run `cargo clippy` and `cargo fmt --check` — fix any new warnings introduced during implementation. Only dead-code warnings acceptable during scaffold phase +- [x] T054 Run full test suite `cargo test` — verify all tests pass. Run with `RUST_LOG=debug` to verify no panics or unexpected error logs - [ ] T055 Validate quickstart flow end-to-end — follow `specs/001-core-runtime/quickstart.md` on a clean checkout: build, configure, start gateway, send first message via websocat, verify response **Checkpoint**: All tests pass, all benchmarks meet constitution performance targets, binary under 25MB, quickstart works end-to-end. diff --git a/src/agent/providers.rs b/src/agent/providers.rs index e2c39b3..b47cbbe 100644 --- a/src/agent/providers.rs +++ b/src/agent/providers.rs @@ -6,6 +6,16 @@ use tracing::debug; use super::AgentEvent; +fn anthropic_endpoint() -> String { + std::env::var("EXOCLAW_ANTHROPIC_ENDPOINT") + .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string()) +} + +fn openai_endpoint() -> String { + std::env::var("EXOCLAW_OPENAI_ENDPOINT") + .unwrap_or_else(|_| "https://api.openai.com/v1/chat/completions".to_string()) +} + /// Trait for LLM provider implementations. #[async_trait] pub trait LlmProvider: Send + Sync { @@ -18,6 +28,29 @@ pub trait LlmProvider: Send + Sync { ) -> anyhow::Result<()>; } +pub struct MockProvider; + +#[async_trait] +impl LlmProvider for MockProvider { + async fn call_streaming( + &self, + _messages: &[serde_json::Value], + _tools: &[serde_json::Value], + _system_prompt: Option<&str>, + tx: mpsc::Sender, + ) -> anyhow::Result<()> { + let _ = tx.send(AgentEvent::Text("mock response".to_string())).await; + let _ = tx + .send(AgentEvent::Usage { + input_tokens: 5, + output_tokens: 1, + }) + .await; + let _ = tx.send(AgentEvent::Done).await; + Ok(()) + } +} + pub struct AnthropicProvider { client: Client, api_key: String, @@ -62,7 +95,7 @@ impl LlmProvider for AnthropicProvider { let response = self .client - .post("https://api.anthropic.com/v1/messages") + .post(anthropic_endpoint()) .header("x-api-key", &self.api_key) .header("anthropic-version", "2023-06-01") .header("content-type", "application/json") @@ -272,7 +305,7 @@ impl LlmProvider for OpenAiProvider { let response = self .client - .post("https://api.openai.com/v1/chat/completions") + .post(openai_endpoint()) .header("Authorization", format!("Bearer {}", self.api_key)) .header("content-type", "application/json") .json(&body) @@ -410,6 +443,10 @@ impl LlmProvider for OpenAiProvider { /// Create a provider from config. pub fn from_config(config: &crate::config::AgentDefConfig) -> anyhow::Result> { + if config.provider == "mock" { + return Ok(Box::new(MockProvider)); + } + let api_key = config.api_key.clone().ok_or_else(|| { anyhow::anyhow!( "no API key for provider '{}'. Set {} env var.", diff --git a/src/gateway/protocol.rs b/src/gateway/protocol.rs index 9221175..83862d3 100644 --- a/src/gateway/protocol.rs +++ b/src/gateway/protocol.rs @@ -6,6 +6,7 @@ use tracing::warn; use super::server::AppState; use crate::agent::AgentEvent; use crate::agent::metering; +use crate::types::Message as AgentMessage; #[derive(Deserialize)] struct RpcRequest { @@ -47,6 +48,8 @@ pub enum RpcResult { Stream { id: String, session_key: String, + agent_id: String, + user_content: String, rx: mpsc::Receiver, }, } @@ -157,16 +160,17 @@ async fn handle_chat_send( session.message_count += 1; } - // 3. Build message history for LLM + // 3. Build message history from memory context + current user message. + let user_message = AgentMessage::text("user", params.content.clone()); let messages = { - let store = state.store.read().await; - match store.get(&route.session_key) { - Some(session) => session.messages.clone(), - None => vec![serde_json::json!({ - "role": "user", - "content": params.content, - })], - } + let mut memory = state.memory.write().await; + let mut context = + memory.assemble_context(&route.session_key, &route.agent_id, ¶ms.content); + context.push(user_message); + context + .into_iter() + .filter_map(|m| m.as_provider_message()) + .collect::>() }; // 4. Budget check before LLM call (T033) @@ -218,6 +222,8 @@ async fn handle_chat_send( let agent_id = route.agent_id.clone(); let meter_session_key = route.session_key.clone(); let plugins = Arc::clone(&state.plugins); + let budget_config = state.config.budgets.clone(); + let session_lock = state.session_lock(&route.session_key).await; // Metering relay: intercepts events to record usage, then forwards to client. tokio::spawn(async move { @@ -228,8 +234,7 @@ async fn handle_chat_send( output_tokens, } = &event { - let counter_mutex = - metering::get_or_init_global(&crate::config::BudgetConfig::default()); + let counter_mutex = metering::get_or_init_global(&budget_config); let mut counter = counter_mutex.lock().unwrap_or_else(|e| e.into_inner()); counter.record_usage( &meter_session_key, @@ -247,6 +252,9 @@ async fn handle_chat_send( }); tokio::spawn(async move { + // Serialize all processing for this session across connections. + let _session_guard = session_lock.lock().await; + let runner = crate::agent::AgentRunner::new(); let result = runner .run_with_tools( @@ -278,6 +286,8 @@ async fn handle_chat_send( RpcResult::Stream { id: request_id, session_key: route.session_key, + agent_id: route.agent_id, + user_content: params.content, rx, } } diff --git a/src/gateway/server.rs b/src/gateway/server.rs index 950c01e..5855b30 100644 --- a/src/gateway/server.rs +++ b/src/gateway/server.rs @@ -17,20 +17,41 @@ use super::auth; use super::protocol::RpcResult; use crate::agent::AgentEvent; use crate::config::ExoclawConfig; +use crate::memory::MemoryEngine; use crate::router::SessionRouter; use crate::sandbox::PluginHost; use crate::store::SessionStore; +use crate::types::Message as AgentMessage; pub struct AppState { pub token: Option, pub router: RwLock, pub plugins: Arc>, pub store: RwLock, + pub memory: Arc>, pub config: ExoclawConfig, /// Per-session locks for message serialization (FR-006). pub session_locks: RwLock>>>, } +impl AppState { + pub async fn session_lock(&self, session_key: &str) -> Arc> { + { + let locks = self.session_locks.read().await; + if let Some(lock) = locks.get(session_key) { + return Arc::clone(lock); + } + } + + let mut locks = self.session_locks.write().await; + Arc::clone( + locks + .entry(session_key.to_string()) + .or_insert_with(|| Arc::new(Mutex::new(()))), + ) + } +} + pub async fn run(config: ExoclawConfig, token: Option) -> anyhow::Result<()> { let is_loopback = config.gateway.bind == "127.0.0.1" || config.gateway.bind == "::1"; @@ -41,6 +62,8 @@ pub async fn run(config: ExoclawConfig, token: Option) -> anyhow::Result ); } + crate::agent::metering::init_global(&config.budgets); + // Populate router with bindings from config let mut router = SessionRouter::new(); for binding in &config.bindings { @@ -72,6 +95,16 @@ pub async fn run(config: ExoclawConfig, token: Option) -> anyhow::Result } info!(plugins = plugin_host.count(), "plugins loaded"); + let mut memory = MemoryEngine::new( + config.memory.episodic_window as usize, + config.memory.semantic_enabled, + ); + if let Some(path) = config.agent.soul_path.as_deref() { + if let Err(e) = memory.soul.load(&config.agent.id, path) { + warn!(agent = %config.agent.id, path, "failed to load soul: {e}"); + } + } + let addr = format!("{}:{}", config.gateway.bind, config.gateway.port); let state = Arc::new(AppState { @@ -79,6 +112,7 @@ pub async fn run(config: ExoclawConfig, token: Option) -> anyhow::Result router: RwLock::new(router), plugins: Arc::new(RwLock::new(plugin_host)), store: RwLock::new(SessionStore::new()), + memory: Arc::new(RwLock::new(memory)), config, session_locks: RwLock::new(HashMap::new()), }); @@ -111,20 +145,22 @@ async fn ws_handler(ws: WebSocketUpgrade, State(state): State>) -> } async fn handle_connection(mut socket: WebSocket, state: Arc) { - // First message must be auth - let authed = match socket.recv().await { - Some(Ok(Message::Text(msg))) => auth::verify_connect(&msg, &state.token), - _ => false, - }; + if state.token.is_some() { + // First message must be auth when token auth is enabled. + let authed = match socket.recv().await { + Some(Ok(Message::Text(msg))) => auth::verify_connect(&msg, &state.token), + _ => false, + }; - if !authed { - let _ = socket - .send(Message::Text( - r#"{"error":"auth_failed","code":4001}"#.into(), - )) - .await; - let _ = socket.close().await; - return; + if !authed { + let _ = socket + .send(Message::Text( + r#"{"error":"auth_failed","code":4001}"#.into(), + )) + .await; + let _ = socket.close().await; + return; + } } let _ = socket @@ -145,6 +181,8 @@ async fn handle_connection(mut socket: WebSocket, state: Arc) { RpcResult::Stream { id, session_key, + agent_id: _agent_id, + user_content, mut rx, } => { // Stream AgentEvents as JSON frames to the client @@ -231,9 +269,20 @@ async fn handle_connection(mut socket: WebSocket, state: Arc) { if let Some(session) = store.get_mut(&session_key) { session.messages.push(serde_json::json!({ "role": "assistant", - "content": assistant_text, + "content": assistant_text.clone(), })); } + + let mut memory = state.memory.write().await; + let user_message = + AgentMessage::text("user", user_content.clone()); + let assistant_message = + AgentMessage::text("assistant", assistant_text.clone()); + memory.process_response( + &session_key, + &user_message, + &assistant_message, + ); } break; } @@ -325,6 +374,8 @@ async fn webhook_handler( return (StatusCode::BAD_REQUEST, "empty message content".to_string()); } + let user_message = AgentMessage::text("user", content.clone()); + // 3. Route to agent let route = { let mut router = state.router.write().await; @@ -337,24 +388,29 @@ async fn webhook_handler( ) }; + let session_lock = state.session_lock(&route.session_key).await; + let _session_guard = session_lock.lock().await; + // 4. Get/create session and append user message { let mut store = state.store.write().await; let session = store.get_or_create(&route.session_key, &route.agent_id); session.messages.push(serde_json::json!({ "role": "user", - "content": content, + "content": content.clone(), })); session.message_count += 1; } - // 5. Build message history + // 5. Build message history using memory engine context let messages = { - let store = state.store.read().await; - match store.get(&route.session_key) { - Some(session) => session.messages.clone(), - None => vec![serde_json::json!({"role": "user", "content": content})], - } + let mut memory = state.memory.write().await; + let mut context = memory.assemble_context(&route.session_key, &route.agent_id, &content); + context.push(user_message.clone()); + context + .into_iter() + .filter_map(|m| m.as_provider_message()) + .collect::>() }; // 6. Create provider and run agent synchronously (collect full response) @@ -425,10 +481,14 @@ async fn webhook_handler( if let Some(session) = store.get_mut(&route.session_key) { session.messages.push(serde_json::json!({ "role": "assistant", - "content": response_text, + "content": response_text.clone(), })); session.message_count += 1; } + + let mut memory = state.memory.write().await; + let assistant_message = AgentMessage::text("assistant", response_text.clone()); + memory.process_response(&route.session_key, &user_message, &assistant_message); } // 9. Format outgoing via channel adapter plugin diff --git a/src/router/mod.rs b/src/router/mod.rs index c84fbb0..f84a689 100644 --- a/src/router/mod.rs +++ b/src/router/mod.rs @@ -26,7 +26,6 @@ pub struct Binding { } struct SessionState { - agent_id: String, message_count: u64, } @@ -124,10 +123,7 @@ impl SessionRouter { self.sessions .entry(session_key.clone()) .and_modify(|s| s.message_count += 1) - .or_insert(SessionState { - agent_id: agent_id.clone(), - message_count: 1, - }); + .or_insert(SessionState { message_count: 1 }); RouteResult { agent_id: agent_id.clone(), diff --git a/src/sandbox/mod.rs b/src/sandbox/mod.rs index e7db8e7..2447821 100644 --- a/src/sandbox/mod.rs +++ b/src/sandbox/mod.rs @@ -266,20 +266,35 @@ fn detect_plugin_type(plugin: &mut Plugin) -> (PluginType, Option>("describe", b"{}") { if let Ok(schema) = serde_json::from_slice::(&output) { + let declared_type = schema + .get("type") + .and_then(|v| v.as_str()) + .or_else(|| schema.get("plugin_type").and_then(|v| v.as_str())); + + if matches!(declared_type, Some("channel_adapter")) { + return (PluginType::ChannelAdapter, None); + } + return (PluginType::Tool, Some(schema)); } } - // Check if it has handle_tool_call (tool) or parse_incoming (channel adapter) + // Fall back to function probing when describe() is unavailable. if plugin - .call::<&[u8], Vec>("handle_tool_call", b"{}") + .call::<&[u8], Vec>("parse_incoming", b"{}") .is_ok() + || plugin + .call::<&[u8], Vec>("format_outgoing", br#"{"content":"ok"}"#) + .is_ok() { - return (PluginType::Tool, None); + return (PluginType::ChannelAdapter, None); } - if plugin.call::<&[u8], Vec>("parse_incoming", b"").is_ok() { - return (PluginType::ChannelAdapter, None); + if plugin + .call::<&[u8], Vec>("handle_tool_call", b"{}") + .is_ok() + { + return (PluginType::Tool, None); } // Default to Tool type diff --git a/src/types.rs b/src/types.rs index f37aebd..5f50c54 100644 --- a/src/types.rs +++ b/src/types.rs @@ -30,6 +30,50 @@ pub enum MessageContent { }, } +impl Message { + /// Create a text message with current timestamp. + pub fn text(role: &str, text: impl Into) -> Self { + Self { + role: role.to_string(), + content: MessageContent::Text { text: text.into() }, + timestamp: chrono::Utc::now(), + token_count: None, + } + } + + /// Convert to a provider-facing message format. + pub fn as_provider_message(&self) -> Option { + match &self.content { + MessageContent::Text { text } => Some(serde_json::json!({ + "role": self.role, + "content": text, + })), + MessageContent::ToolUse { id, name, input } => Some(serde_json::json!({ + "role": "assistant", + "content": [{ + "type": "tool_use", + "id": id, + "name": name, + "input": input, + }], + })), + MessageContent::ToolResult { + tool_use_id, + content, + is_error, + } => Some(serde_json::json!({ + "role": "user", + "content": [{ + "type": "tool_result", + "tool_use_id": tool_use_id, + "content": content, + "is_error": is_error, + }], + })), + } + } +} + /// Normalized incoming message from any channel. /// Used by the router to determine the target agent and session. #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/tests/channel_test.rs b/tests/channel_test.rs index d75c937..d66671f 100644 --- a/tests/channel_test.rs +++ b/tests/channel_test.rs @@ -1,5 +1,5 @@ -use exoclaw::sandbox::PluginHost; use exoclaw::sandbox::capabilities::Capability; +use exoclaw::sandbox::{PluginHost, PluginType}; /// Path to the mock-channel plugin WASM binary. fn mock_channel_wasm_path() -> String { @@ -15,18 +15,9 @@ fn register_mock_channel_detects_adapter_type() { host.register("mock", &mock_channel_wasm_path(), vec![]) .unwrap(); - // Should not appear as a tool plugin since it has parse_incoming assert!(host.has_plugin("mock")); - - // The describe() export returns channel_adapter type info, but detect_plugin_type - // checks describe() first (which returns valid JSON → Tool), then falls back. - // Since describe() returns valid JSON, it will be detected as Tool with schema. - // The plugin_type method exposes this. - // Note: In practice, the describe() for a channel adapter could return - // a type field that we inspect. For now, the mock plugin has describe() - // returning valid JSON so it's detected as Tool type. - // The find_channel_adapter lookup works by matching PluginType::ChannelAdapter. - // We need to verify parse_incoming works regardless of type detection. + assert_eq!(host.plugin_type("mock"), Some(&PluginType::ChannelAdapter)); + assert_eq!(host.find_channel_adapter("mock"), Some("mock")); } #[test] diff --git a/tests/gateway_test.rs b/tests/gateway_test.rs new file mode 100644 index 0000000..05cfff6 --- /dev/null +++ b/tests/gateway_test.rs @@ -0,0 +1,303 @@ +use axum::{Router, http::header, response::IntoResponse, routing::post}; +use exoclaw::config::ExoclawConfig; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; +use tokio::time::{Duration, sleep, timeout}; + +async fn mock_openai_handler() -> impl IntoResponse { + let body = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"},\"finish_reason\":null}]}\n\n", + "data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":1}}\n\n", + "data: [DONE]\n\n" + ); + ([(header::CONTENT_TYPE, "text/event-stream")], body) +} + +async fn start_mock_openai_server() -> (String, tokio::task::JoinHandle<()>) { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let app = Router::new().route("/v1/chat/completions", post(mock_openai_handler)); + let handle = tokio::spawn(async move { + let _ = axum::serve(listener, app).await; + }); + (format!("http://{addr}/v1/chat/completions"), handle) +} + +fn free_port() -> u16 { + std::net::TcpListener::bind("127.0.0.1:0") + .unwrap() + .local_addr() + .unwrap() + .port() +} + +struct WsClient { + stream: TcpStream, + read_buffer: Vec, +} + +impl WsClient { + async fn connect(host: &str, port: u16, path: &str) -> anyhow::Result { + let mut stream = TcpStream::connect((host, port)).await?; + let request = format!( + "GET {path} HTTP/1.1\r\n\ + Host: {host}:{port}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\ + \r\n" + ); + stream.write_all(request.as_bytes()).await?; + + let mut response = Vec::new(); + let header_end; + loop { + let mut buf = [0u8; 1024]; + let n = stream.read(&mut buf).await?; + if n == 0 { + anyhow::bail!("websocket handshake closed early"); + } + response.extend_from_slice(&buf[..n]); + if let Some(pos) = response.windows(4).position(|w| w == b"\r\n\r\n") { + header_end = pos + 4; + break; + } + } + + let response_text = String::from_utf8_lossy(&response[..header_end]); + anyhow::ensure!( + response_text.starts_with("HTTP/1.1 101"), + "unexpected websocket handshake response: {response_text}" + ); + + let read_buffer = response[header_end..].to_vec(); + + Ok(Self { + stream, + read_buffer, + }) + } + + async fn send_text(&mut self, payload: &str) -> anyhow::Result<()> { + let payload = payload.as_bytes(); + let mut frame = Vec::with_capacity(payload.len() + 14); + frame.push(0x81); // FIN + text frame + + let mask_bit = 0x80u8; + if payload.len() < 126 { + frame.push(mask_bit | payload.len() as u8); + } else if payload.len() <= u16::MAX as usize { + frame.push(mask_bit | 126); + frame.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + } else { + frame.push(mask_bit | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + + let mask = [0x12u8, 0x34, 0x56, 0x78]; + frame.extend_from_slice(&mask); + for (i, b) in payload.iter().enumerate() { + frame.push(b ^ mask[i % 4]); + } + + self.stream.write_all(&frame).await?; + Ok(()) + } + + async fn read_exact_ws(&mut self, buf: &mut [u8]) -> anyhow::Result<()> { + let mut offset = 0usize; + while offset < buf.len() { + if !self.read_buffer.is_empty() { + let take = (buf.len() - offset).min(self.read_buffer.len()); + buf[offset..offset + take].copy_from_slice(&self.read_buffer[..take]); + self.read_buffer.drain(..take); + offset += take; + continue; + } + + let n = self.stream.read(&mut buf[offset..]).await?; + if n == 0 { + anyhow::bail!("connection closed while reading websocket frame"); + } + offset += n; + } + + Ok(()) + } + + async fn recv_text(&mut self) -> anyhow::Result { + let mut header = [0u8; 2]; + self.read_exact_ws(&mut header).await?; + + let opcode = header[0] & 0x0f; + let masked = (header[1] & 0x80) != 0; + let mut len = (header[1] & 0x7f) as u64; + + if len == 126 { + let mut ext = [0u8; 2]; + self.read_exact_ws(&mut ext).await?; + len = u16::from_be_bytes(ext) as u64; + } else if len == 127 { + let mut ext = [0u8; 8]; + self.read_exact_ws(&mut ext).await?; + len = u64::from_be_bytes(ext); + } + + let mut mask = [0u8; 4]; + if masked { + self.read_exact_ws(&mut mask).await?; + } + + let mut payload = vec![0u8; len as usize]; + self.read_exact_ws(&mut payload).await?; + + if masked { + for (i, b) in payload.iter_mut().enumerate() { + *b ^= mask[i % 4]; + } + } + + match opcode { + 0x1 => Ok(String::from_utf8(payload)?), + 0x8 => anyhow::bail!("received close frame"), + other => anyhow::bail!("unexpected opcode: {other}"), + } + } + + async fn recv_json(&mut self) -> anyhow::Result { + let text = self.recv_text().await?; + let value: serde_json::Value = serde_json::from_str(&text)?; + Ok(value) + } + + async fn recv_json_timeout(&mut self, label: &str) -> anyhow::Result { + timeout(Duration::from_secs(5), self.recv_json()) + .await + .map_err(|_| anyhow::anyhow!("timeout waiting for websocket frame: {label}"))? + } +} + +async fn connect_ws_with_retry(host: &str, port: u16) -> WsClient { + let mut last_err = None; + for _ in 0..40 { + match WsClient::connect(host, port, "/ws").await { + Ok(client) => return client, + Err(e) => { + last_err = Some(e); + sleep(Duration::from_millis(50)).await; + } + } + } + panic!("failed to connect websocket: {last_err:?}"); +} + +fn gateway_config(port: u16, openai: bool) -> ExoclawConfig { + let mut config = ExoclawConfig::default(); + config.gateway.bind = "127.0.0.1".to_string(); + config.gateway.port = port; + if openai { + config.agent.provider = "openai".to_string(); + config.agent.model = "gpt-4o".to_string(); + config.agent.api_key = Some("test-key".to_string()); + } + config +} + +#[tokio::test] +async fn authenticated_chat_send_streams_text_and_done() { + let (mock_endpoint, mock_handle) = start_mock_openai_server().await; + // SAFETY: test-scoped env mutation for provider endpoint override. + unsafe { + std::env::set_var("EXOCLAW_OPENAI_ENDPOINT", &mock_endpoint); + } + + let port = free_port(); + let config = gateway_config(port, true); + let gateway = tokio::spawn(async move { + let _ = exoclaw::gateway::run(config, Some("secret-token".to_string())).await; + }); + + let mut ws = connect_ws_with_retry("127.0.0.1", port).await; + ws.send_text(r#"{"token":"secret-token"}"#).await.unwrap(); + + let hello = ws.recv_json_timeout("auth hello").await.unwrap(); + assert_eq!(hello["ok"], true); + + ws.send_text( + r#"{"id":"chat1","method":"chat.send","params":{"channel":"websocket","account":"me","content":"hello there"}}"#, + ) + .await + .unwrap(); + + let mut saw_text = false; + let mut saw_done = false; + for _ in 0..10 { + let frame = ws.recv_json_timeout("chat frame").await.unwrap(); + if frame["id"] != "chat1" { + continue; + } + match frame["event"].as_str() { + Some("text") => saw_text = true, + Some("done") => { + saw_done = true; + break; + } + Some("error") => panic!("unexpected stream error: {frame}"), + _ => {} + } + } + + assert!(saw_text, "expected at least one text event"); + assert!(saw_done, "expected done event"); + + // SAFETY: undo test-scoped env mutation. + unsafe { + std::env::remove_var("EXOCLAW_OPENAI_ENDPOINT"); + } + gateway.abort(); + let _ = gateway.await; + mock_handle.abort(); + let _ = mock_handle.await; +} + +#[tokio::test] +async fn unauthenticated_connection_is_rejected() { + let port = free_port(); + let config = gateway_config(port, false); + let gateway = tokio::spawn(async move { + let _ = exoclaw::gateway::run(config, Some("secret-token".to_string())).await; + }); + + let mut ws = connect_ws_with_retry("127.0.0.1", port).await; + ws.send_text(r#"{"token":"wrong-token"}"#).await.unwrap(); + let response = ws.recv_json_timeout("unauth response").await.unwrap(); + assert_eq!(response["error"], "auth_failed"); + assert_eq!(response["code"], 4001); + + gateway.abort(); + let _ = gateway.await; +} + +#[tokio::test] +async fn loopback_mode_allows_no_auth_ping() { + let port = free_port(); + let config = gateway_config(port, false); + let gateway = tokio::spawn(async move { + let _ = exoclaw::gateway::run(config, None).await; + }); + + let mut ws = connect_ws_with_retry("127.0.0.1", port).await; + let hello = ws.recv_json_timeout("loopback hello").await.unwrap(); + assert_eq!(hello["ok"], true); + + ws.send_text(r#"{"id":"p1","method":"ping"}"#) + .await + .unwrap(); + let pong = ws.recv_json_timeout("loopback pong").await.unwrap(); + assert_eq!(pong["id"], "p1"); + assert_eq!(pong["result"], "pong"); + + gateway.abort(); + let _ = gateway.await; +} diff --git a/tests/protocol_test.rs b/tests/protocol_test.rs new file mode 100644 index 0000000..0f0f9e5 --- /dev/null +++ b/tests/protocol_test.rs @@ -0,0 +1,135 @@ +use exoclaw::agent::AgentEvent; +use exoclaw::config::ExoclawConfig; +use exoclaw::gateway::protocol::{RpcResult, handle_rpc}; +use exoclaw::gateway::server::AppState; +use exoclaw::memory::MemoryEngine; +use exoclaw::router::SessionRouter; +use exoclaw::sandbox::PluginHost; +use exoclaw::store::SessionStore; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::{Duration, timeout}; + +fn build_state(config: ExoclawConfig) -> Arc { + Arc::new(AppState { + token: None, + router: RwLock::new(SessionRouter::new()), + plugins: Arc::new(RwLock::new(PluginHost::new())), + store: RwLock::new(SessionStore::new()), + memory: Arc::new(RwLock::new(MemoryEngine::new( + config.memory.episodic_window as usize, + config.memory.semantic_enabled, + ))), + config, + session_locks: RwLock::new(HashMap::>>::new()), + }) +} + +#[tokio::test] +async fn ping_returns_pong() { + let state = build_state(ExoclawConfig::default()); + let result = handle_rpc(r#"{"id":"1","method":"ping"}"#, &state).await; + let RpcResult::Response(resp) = result else { + panic!("expected response"); + }; + let parsed: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(parsed["id"], "1"); + assert_eq!(parsed["result"], "pong"); +} + +#[tokio::test] +async fn status_returns_version_plugins_sessions() { + let state = build_state(ExoclawConfig::default()); + let result = handle_rpc(r#"{"id":"2","method":"status"}"#, &state).await; + let RpcResult::Response(resp) = result else { + panic!("expected response"); + }; + let parsed: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(parsed["id"], "2"); + assert!(parsed["result"]["version"].is_string()); + assert_eq!(parsed["result"]["plugins"], 0); + assert_eq!(parsed["result"]["sessions"], 0); +} + +#[tokio::test] +async fn chat_send_missing_params_returns_error() { + let state = build_state(ExoclawConfig::default()); + let result = handle_rpc( + r#"{"id":"3","method":"chat.send","params":{"channel":"ws"}}"#, + &state, + ) + .await; + let RpcResult::Response(resp) = result else { + panic!("expected response"); + }; + let parsed: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(parsed["id"], "3"); + assert!( + parsed["error"] + .as_str() + .unwrap_or("") + .contains("invalid chat.send params") + ); +} + +#[tokio::test] +async fn unknown_method_returns_error() { + let state = build_state(ExoclawConfig::default()); + let result = handle_rpc(r#"{"id":"4","method":"nope.method"}"#, &state).await; + let RpcResult::Response(resp) = result else { + panic!("expected response"); + }; + let parsed: serde_json::Value = serde_json::from_str(&resp).unwrap(); + assert_eq!(parsed["id"], "4"); + assert_eq!(parsed["error"], "unknown method: nope.method"); +} + +#[tokio::test] +async fn chat_send_valid_params_streams_response() { + let mut config = ExoclawConfig::default(); + config.agent.provider = "mock".to_string(); + let state = build_state(config); + + let result = handle_rpc( + r#"{"id":"5","method":"chat.send","params":{"channel":"websocket","account":"me","content":"hello world"}}"#, + &state, + ) + .await; + + let RpcResult::Stream { + id, + session_key, + mut rx, + .. + } = result + else { + panic!("expected stream"); + }; + + assert_eq!(id, "5"); + assert_eq!(session_key, "default:websocket:me:main"); + + let mut text = String::new(); + let mut saw_usage = false; + let mut saw_done = false; + + loop { + let next = timeout(Duration::from_secs(5), rx.recv()).await.unwrap(); + match next { + Some(AgentEvent::Text(chunk)) => text.push_str(&chunk), + Some(AgentEvent::Usage { .. }) => saw_usage = true, + Some(AgentEvent::Done) => { + saw_done = true; + break; + } + Some(AgentEvent::Error(err)) => panic!("unexpected stream error: {err}"), + Some(_) => {} + None => break, + } + } + + assert!(text.contains("mock response")); + assert!(saw_usage); + assert!(saw_done); +}