diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml new file mode 100644 index 0000000..b55e77c --- /dev/null +++ b/.github/workflows/integration.yml @@ -0,0 +1,31 @@ +name: Integration Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + integration: + name: Integration tests + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Run unit and local integration tests + run: cargo test --all-features + + - name: Run integration example (local RMCP) + run: cargo run --example rmcp_integration_test --features rmcp -- local diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..43b2061 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,42 @@ +name: Rust CI + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + ci: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Check + run: cargo check --all --all-features + + - name: Test + run: cargo test --all --all-features + + - name: Format check + run: cargo fmt --all -- --check + + - name: Clippy + run: cargo clippy --all --all-features -- -D warnings + + - name: Doc check + run: cargo doc --no-deps --all-features + + - name: Build + run: cargo build --verbose \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..694916f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,32 @@ +# Changelog + +## [0.1.0] - 2026-05-07 + +### Added + +- Core transport layer: `NostrClientTransport` and `NostrServerTransport` over NIP-59 gift wraps +- Gateway and Proxy high-level APIs for bridging MCP over Nostr +- Discovery API: `discover_servers`, `discover_tools`, `discover_resources`, `discover_prompts`, `discover_resource_templates` +- CEP-6: server announcement publishing and querying (kinds 11316–11320) +- CEP-19: ephemeral gift wraps (kind 21059) with `GiftWrapMode` negotiation on both client and server +- CEP-35: stateless session discovery, tag composition, and capability learning +- LRU-bounded session store with configurable capacity (default 1000 sessions) and TTL expiry +- Multi-client support in `NostrServerWorker` (removed single-peer barrier) +- Direct rmcp transport adapters via `into_rmcp_transport()` for native `ContextVM` services +- `CancellationToken`-based graceful shutdown on `close()` +- TTL sweep for client and server correlation stores to prevent pending-request leaks +- `MockRelayPool` for deterministic offline testing +- Builder pattern for all transport and worker configuration structs +- Four examples: gateway, proxy, discovery, and rmcp integration test + +### Fixed + +- Single-peer barrier in RMCP worker rejected concurrent clients (#60) +- Pending-request leak: correlation store entries never expired by TTL (#61) +- Event loop tasks not cancelled on `close()`, causing resource leaks (#63) +- `RecvError::Lagged` killing event loop under high relay throughput (#68) +- Client race condition: responses lost when publish completed before correlation registration (#55) +- Uncorrelated responses (missing `e` tag) forwarded to consumer instead of dropped (#55) +- Non-atomic `send_response` behavior in server transport (#48) +- Unbounded LRU cache initialization with zero capacity (#50) +- Announced servers not sending JSON-RPC `-32000 Unauthorized` error for disallowed clients (#53) diff --git a/Cargo.toml b/Cargo.toml index 165dae7..7abf096 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,15 @@ name = "contextvm-sdk" version = "0.1.0" edition = "2021" +rust-version = "1.70" description = "Rust SDK for the ContextVM protocol — MCP over Nostr" license = "MIT" -repository = "https://github.com/k0sti/rust-contextvm-sdk" +readme = "README.md" +repository = "https://github.com/ContextVM/rs-sdk" +homepage = "https://contextvm.org" +documentation = "https://docs.rs/contextvm-sdk" +keywords = ["nostr", "mcp", "model-context-protocol", "decentralized", "ai"] +categories = ["network-programming", "api-bindings", "asynchronous"] [dependencies] # Async runtime @@ -24,6 +30,34 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" +# Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) +rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } + +# LRU cache for gift-wrap (outer event id) deduplication +lru = "0.12" + +# CancellationToken for graceful event-loop shutdown +tokio-util = { version = "0.7", features = ["rt"] } + +[features] +# Enable rmcp by default while keeping legacy APIs available. +default = ["rmcp"] +rmcp = ["dep:rmcp"] + +[[example]] +name = "rmcp_integration_test" +required-features = ["rmcp"] + +[[example]] +name = "native_echo_server" +required-features = ["rmcp"] + +[[example]] +name = "native_echo_client" +required-features = ["rmcp"] + [dev-dependencies] tokio-test = "0.4" -tracing-subscriber = "0.3" +anyhow = "1" +schemars = "0.8" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/DESIGN.md b/DESIGN.md index fc9467f..7fe0ab7 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -221,7 +221,7 @@ tracing-subscriber = "0.3" - Unit test: rejects events from wrong server pubkey - [x] **2.4** Implement `NostrServerTransport` - - Config: relay_urls, encryption_mode, server_info, is_public_server, allowed_public_keys, excluded_capabilities, cleanup_interval_ms, session_timeout_ms + - Config: relay_urls, encryption_mode, server_info, is_announced_server, allowed_public_keys, excluded_capabilities, cleanup_interval_ms, session_timeout_ms - Implements `Transport` trait - Features: - Subscribe to events targeting server pubkey diff --git a/README.md b/README.md index 79ed657..ffdb4e5 100644 --- a/README.md +++ b/README.md @@ -77,19 +77,16 @@ use contextvm_sdk::signer; async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); - let config = GatewayConfig { - nostr_config: NostrServerTransportConfig { - relay_urls: vec!["wss://relay.damus.io".into()], - encryption_mode: EncryptionMode::Optional, - server_info: Some(ServerInfo { - name: Some("My MCP Server".into()), - about: Some("Tools via Nostr".into()), - ..Default::default() - }), - is_public_server: true, - ..Default::default() - }, - }; + let config = GatewayConfig::new( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_server_info( + ServerInfo::default() + .with_name("My MCP Server") + .with_about("Tools via Nostr"), + ) + .with_announced_server(true), + ); let mut gateway = NostrMCPGateway::new(keys, config).await?; let mut requests = gateway.start().await?; @@ -116,14 +113,11 @@ use contextvm_sdk::signer; async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); - let config = ProxyConfig { - nostr_config: NostrClientTransportConfig { - relay_urls: vec!["wss://relay.damus.io".into()], - server_pubkey: "abc123...server_hex_pubkey".into(), - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, - }; + let config = ProxyConfig::new( + NostrClientTransportConfig::default() + .with_server_pubkey("abc123...server_hex_pubkey") + .with_encryption_mode(EncryptionMode::Optional), + ); let mut proxy = NostrMCPProxy::new(keys, config).await?; let mut responses = proxy.start().await?; @@ -201,7 +195,7 @@ metadata-private delivery. Server announcements (kinds 11316–11320) are always | `relay_urls` | `["wss://relay.damus.io"]` | Nostr relays to connect to | | `encryption_mode` | `Optional` | Encryption policy | | `server_info` | `None` | Server metadata for announcements | -| `is_public_server` | `false` | Whether to publish announcements | +| `is_announced_server` | `false` | Whether to publish announcements (CEP-6) | | `allowed_public_keys` | `[]` (allow all) | Client pubkey allowlist (hex) | | `excluded_capabilities` | `[]` | Methods exempt from allowlist | | `session_timeout` | `300s` | Inactive session expiry | diff --git a/examples/discovery.rs b/examples/discovery.rs index 166cb3f..edff853 100644 --- a/examples/discovery.rs +++ b/examples/discovery.rs @@ -53,8 +53,7 @@ async fn main() -> contextvm_sdk::Result<()> { println!(" Resources: {} found", resources.len()); } - let prompts = - discovery::discover_prompts(client, &server.pubkey_parsed, &relays).await?; + let prompts = discovery::discover_prompts(client, &server.pubkey_parsed, &relays).await?; if !prompts.is_empty() { println!(" Prompts: {} found", prompts.len()); } diff --git a/examples/gateway.rs b/examples/gateway.rs index ee6b1c5..efe8500 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -2,6 +2,8 @@ //! //! This demonstrates how to create a ContextVM gateway that receives //! MCP requests over Nostr and responds to them. +//! +//! Usage: cargo run --example gateway use contextvm_sdk::core::types::*; use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; @@ -17,18 +19,14 @@ async fn main() -> contextvm_sdk::Result<()> { println!("Server pubkey: {}", keys.public_key().to_hex()); // Configure the gateway - let config = GatewayConfig { - nostr_config: NostrServerTransportConfig { - relay_urls: vec!["wss://relay.damus.io".to_string()], - server_info: Some(ServerInfo { - name: Some("Echo Server".to_string()), - about: Some("A simple echo tool exposed via ContextVM".to_string()), - ..Default::default() - }), - is_public_server: true, - ..Default::default() - }, - }; + let nostr_config = NostrServerTransportConfig::default() + .with_server_info( + ServerInfo::default() + .with_name("Echo Server") + .with_about("A simple echo tool exposed via ContextVM"), + ) + .with_announced_server(true); + let config = GatewayConfig::new(nostr_config); let mut gateway = NostrMCPGateway::new(keys, config).await?; let mut rx = gateway.start().await?; diff --git a/examples/native_echo_client.rs b/examples/native_echo_client.rs new file mode 100644 index 0000000..be5e0dd --- /dev/null +++ b/examples/native_echo_client.rs @@ -0,0 +1,94 @@ +//! Example: Native rmcp client over ContextVM/Nostr. +//! +//! Usage: +//! cargo run --example native_echo_client -- + +use anyhow::{Context, Result}; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::{signer, EncryptionMode, GiftWrapMode}; +use rmcp::{ + model::{CallToolRequestParams, CallToolResult}, + ClientHandler, ServiceExt, +}; + +const RELAY_URL: &str = "wss://relay.contextvm.org"; + +#[derive(Clone, Default)] +struct EchoClient; + +impl ClientHandler for EchoClient {} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("contextvm_sdk=info".parse()?) + .add_directive("rmcp=warn".parse()?), + ) + .init(); + + let server_pubkey = std::env::args() + .nth(1) + .context("Usage: native_echo_client ")?; + + let signer = signer::generate(); + println!("Native ContextVM echo client starting"); + println!("Relay: {RELAY_URL}"); + println!("Client pubkey: {}", signer.public_key().to_hex()); + println!("Target server pubkey: {server_pubkey}"); + + let transport = NostrClientTransport::new( + signer, + NostrClientTransportConfig::default() + .with_relay_urls(vec![RELAY_URL.to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), + ) + .await?; + + let client = EchoClient.serve(transport).await?; + + let peer_info = client + .peer_info() + .context("server did not provide peer info after initialize")?; + println!("Connected to: {:?}", peer_info.server_info.name); + + let tools = client.list_all_tools().await?; + println!("Discovered {} tool(s):", tools.len()); + for tool in &tools { + println!("- {}", tool.name); + } + + let result = client + .call_tool(CallToolRequestParams { + name: "echo".into(), + arguments: serde_json::from_value(serde_json::json!({ + "message": "hello from native contextvm client" + })) + .ok(), + meta: None, + task: None, + }) + .await?; + + println!("Echo result: {}", first_text(&result)); + + client.cancel().await?; + Ok(()) +} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|content| { + if let rmcp::model::RawContent::Text(text) = &content.raw { + Some(text.text.clone()) + } else { + None + } + }) + .unwrap_or_default() +} diff --git a/examples/native_echo_server.rs b/examples/native_echo_server.rs new file mode 100644 index 0000000..e463e62 --- /dev/null +++ b/examples/native_echo_server.rs @@ -0,0 +1,103 @@ +//! Example: Native rmcp echo server over ContextVM/Nostr. +//! +//! Usage: +//! cargo run --example native_echo_server + +use anyhow::Result; +use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use contextvm_sdk::{signer, EncryptionMode, GiftWrapMode, ServerInfo}; +use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::*, + schemars, tool, tool_handler, tool_router, ServerHandler, ServiceExt, +}; + +const RELAY_URL: &str = "wss://relay.contextvm.org"; + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Clone)] +struct EchoServer { + tool_router: ToolRouter, +} + +impl EchoServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl EchoServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo: {message}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for EchoServer { + fn get_info(&self) -> rmcp::model::ServerInfo { + rmcp::model::ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "contextvm-native-echo".to_string(), + title: Some("ContextVM Native Echo Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Native rmcp echo server over ContextVM/Nostr".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Call the echo tool with a message string".to_string()), + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("contextvm_sdk=info".parse()?) + .add_directive("rmcp=warn".parse()?), + ) + .init(); + + let signer = signer::generate(); + let pubkey = signer.public_key().to_hex(); + + println!("Native ContextVM echo server starting"); + println!("Relay: {RELAY_URL}"); + println!("Server pubkey: {pubkey}"); + + let transport = NostrServerTransport::new( + signer, + NostrServerTransportConfig::default() + .with_relay_urls(vec![RELAY_URL.to_string()]) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_announced_server(false) + .with_server_info( + ServerInfo::default() + .with_name("contextvm-native-echo".to_string()) + .with_about("Native rmcp echo server example".to_string()), + ), + ) + .await?; + + let service = EchoServer::new().serve(transport).await?; + println!("Server ready. Press Ctrl+C to stop."); + service.waiting().await?; + Ok(()) +} diff --git a/examples/proxy.rs b/examples/proxy.rs index a10663a..4aea3e3 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -6,6 +6,7 @@ use contextvm_sdk::core::types::*; use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; use contextvm_sdk::signer; use contextvm_sdk::transport::client::NostrClientTransportConfig; + #[tokio::main] async fn main() -> contextvm_sdk::Result<()> { tracing_subscriber::fmt::init(); @@ -17,14 +18,10 @@ async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); println!("Client pubkey: {}", keys.public_key().to_hex()); - let config = ProxyConfig { - nostr_config: NostrClientTransportConfig { - relay_urls: vec!["wss://relay.damus.io".to_string()], - server_pubkey: server_pubkey_hex, - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, - }; + let nostr_config = NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey_hex) + .with_encryption_mode(EncryptionMode::Optional); + let config = ProxyConfig::new(nostr_config); let mut proxy = NostrMCPProxy::new(keys, config).await?; let mut rx = proxy.start().await?; @@ -42,7 +39,10 @@ async fn main() -> contextvm_sdk::Result<()> { // Wait for response if let Some(response) = rx.recv().await { - println!("Response: {}", serde_json::to_string_pretty(&response).unwrap()); + println!( + "Response: {}", + serde_json::to_string_pretty(&response).unwrap() + ); } proxy.stop().await?; diff --git a/examples/rmcp_integration_test.rs b/examples/rmcp_integration_test.rs new file mode 100644 index 0000000..9d82519 --- /dev/null +++ b/examples/rmcp_integration_test.rs @@ -0,0 +1,719 @@ +//! Comprehensive rmcp integration matrix for ContextVM SDK. +//! +//! This example validates three scenarios: +//! 1) local rmcp transport (in-process duplex) +//! 2) hybrid relay mode (rmcp server + legacy JSON-RPC client) +//! 3) full rmcp over relays (rmcp server + rmcp client) +//! +//! Run: +//! cargo run --example rmcp_integration_test --features rmcp +//! cargo run --example rmcp_integration_test --features rmcp -- local +//! cargo run --example rmcp_integration_test --features rmcp -- hybrid +//! cargo run --example rmcp_integration_test --features rmcp -- relay-rmcp +//! cargo run --example rmcp_integration_test --features rmcp -- all +//! +//! Optional relay override: +//! CTXVM_RELAY_URL=wss://relay.primal.net cargo run --example rmcp_integration_test --features rmcp -- all +//! cargo run --example rmcp_integration_test --features rmcp -- all wss://relay.primal.net + +use anyhow::{anyhow, bail, Context, Result}; +use contextvm_sdk::core::constants::mcp_protocol_version; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, + ServerInfo as CtxServerInfo, +}; +use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; +use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; +use contextvm_sdk::signer; +use contextvm_sdk::transport::client::NostrClientTransportConfig; +use contextvm_sdk::transport::server::NostrServerTransportConfig; +use rmcp::{ + handler::server::wrapper::Parameters, model::*, schemars, service::RequestContext, tool, + tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, ServiceExt, +}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tokio::time::{sleep, timeout}; + +const DEFAULT_RELAY_URL: &str = "wss://relay.primal.net"; +const IO_TIMEOUT: Duration = Duration::from_secs(30); +const RELAY_WARMUP: Duration = Duration::from_secs(2); +const STARTUP_TIMEOUT: Duration = Duration::from_secs(20); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Mode { + Local, + Hybrid, + RelayRmcp, + All, +} + +impl Mode { + fn parse(value: Option<&str>) -> Result { + match value.unwrap_or("all") { + "local" => Ok(Self::Local), + "hybrid" => Ok(Self::Hybrid), + "relay-rmcp" => Ok(Self::RelayRmcp), + "all" => Ok(Self::All), + other => bail!("Unknown mode '{other}'. Use one of: local | hybrid | relay-rmcp | all"), + } + } + + fn run_local(self) -> bool { + matches!(self, Self::Local | Self::All) + } + + fn run_hybrid(self) -> bool { + matches!(self, Self::Hybrid | Self::All) + } + + fn run_relay_rmcp(self) -> bool { + matches!(self, Self::RelayRmcp | Self::All) + } +} + +// Parameter structs with JSON schema for tools/list. +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct AddParams { + a: i64, + b: i64, +} + +use rmcp::handler::server::router::tool::ToolRouter; + +#[derive(Clone)] +struct DemoServer { + echo_count: Arc>, + tool_router: ToolRouter, +} + +impl DemoServer { + fn new() -> Self { + Self { + echo_count: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl DemoServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + let mut n = self.echo_count.lock().await; + *n += 1; + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo #{n}: {message}" + ))])) + } + + #[tool(description = "Add two integers and return their sum")] + fn add( + &self, + Parameters(AddParams { a, b }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "{a} + {b} = {}", + a + b + ))])) + } + + #[tool(description = "Return the total number of echo calls made so far")] + async fn get_echo_count(&self) -> Result { + let n = self.echo_count.lock().await; + Ok(CallToolResult::success(vec![Content::text(format!( + "Total echo calls: {n}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for DemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + server_info: Implementation { + name: "contextvm-demo".to_string(), + title: Some("ContextVM Demo Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Demonstrates rmcp integration over ContextVM".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Try: echo, add, get_echo_count".to_string()), + } + } + + async fn list_resources( + &self, + _req: Option, + _ctx: RequestContext, + ) -> Result { + Ok(ListResourcesResult { + resources: vec![ + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation() + ], + next_cursor: None, + meta: None, + }) + } + + async fn read_resource( + &self, + req: ReadResourceRequestParams, + _ctx: RequestContext, + ) -> Result { + match req.uri.as_str() { + "demo://readme" => Ok(ReadResourceResult { + contents: vec![ResourceContents::text( + "This server demonstrates the ContextVM rmcp integration.", + req.uri, + )], + }), + other => Err(ErrorData::resource_not_found( + "not_found", + Some(serde_json::json!({ "uri": other })), + )), + } + } +} + +#[derive(Clone, Default)] +struct DemoClient; +impl ClientHandler for DemoClient {} + +#[derive(Clone, Default)] +struct RelayRmcpClient; +impl ClientHandler for RelayRmcpClient {} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("rmcp=warn".parse()?) + .add_directive("contextvm_sdk=info".parse()?), + ) + .init(); + + let args: Vec = std::env::args().skip(1).collect(); + let mode = Mode::parse(args.first().map(String::as_str))?; + let relay_url = args + .get(1) + .cloned() + .or_else(|| std::env::var("CTXVM_RELAY_URL").ok()) + .unwrap_or_else(|| DEFAULT_RELAY_URL.to_string()); + + println!("========================================"); + println!("ContextVM SDK rmcp integration matrix"); + println!("mode: {:?}", mode); + println!("relay: {relay_url}"); + println!("========================================\n"); + + if mode.run_local() { + run_local_rmcp_case().await?; + } + + if mode.run_hybrid() { + run_hybrid_relay_case(&relay_url).await?; + } + + if mode.run_relay_rmcp() { + run_relay_rmcp_case(&relay_url).await?; + } + + println!("\nAll selected integration scenarios passed."); + Ok(()) +} + +async fn run_local_rmcp_case() -> Result<()> { + println!("[local-rmcp] start"); + + let (server_io, client_io) = tokio::io::duplex(65536); + + let server_handle = tokio::spawn(async move { + DemoServer::new() + .serve(server_io) + .await + .expect("server serve failed") + .waiting() + .await + .expect("server error"); + }); + + let client = DemoClient.serve(client_io).await?; + + let tools = client.list_all_tools().await?; + assert_eq!(tools.len(), 3, "expected 3 tools in local rmcp case"); + + let add_result = client + .call_tool(call_params( + "add", + Some(serde_json::json!({ "a": 7, "b": 5 })), + )) + .await?; + let add_text = first_text(&add_result); + assert!(add_text.contains("12"), "expected add result to include 12"); + + let resources = client.list_all_resources().await?; + assert_eq!( + resources.len(), + 1, + "expected one resource in local rmcp case" + ); + + match client.call_tool(call_params("no_such_tool", None)).await { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => bail!("expected unknown tool to fail in local rmcp case"), + } + + client.cancel().await?; + server_handle.abort(); + + println!("[local-rmcp] pass"); + Ok(()) +} + +async fn run_hybrid_relay_case(relay_url: &str) -> Result<()> { + println!("[relay-hybrid] start (rmcp server + legacy client)"); + + let server_keys = signer::generate(); + let server_pubkey_hex = server_keys.public_key().to_hex(); + + println!("[relay-hybrid] stage: spawning rmcp server task"); + let relay_url_owned = relay_url.to_string(); + let server_task = tokio::spawn(async move { + let server = NostrMCPGateway::serve_handler( + server_keys, + server_config(&relay_url_owned), + DemoServer::new(), + ) + .await + .with_context(|| format!("failed to start rmcp server on relay {relay_url_owned}"))?; + + let _ = server + .waiting() + .await + .map_err(|e| anyhow!("rmcp server exited with error: {e}"))?; + + Err(anyhow!("rmcp server stopped unexpectedly")) + }); + + sleep(RELAY_WARMUP).await; + + if server_task.is_finished() { + let res = server_task + .await + .map_err(|e| anyhow!("rmcp server task join error: {e}"))?; + return res.context("rmcp server task ended before client startup"); + } + + let outcome: Result<()> = async { + println!("[relay-hybrid] stage: creating legacy proxy client"); + + let mut proxy = timeout( + STARTUP_TIMEOUT, + NostrMCPProxy::new( + signer::generate(), + client_config(relay_url, server_pubkey_hex.clone()), + ), + ) + .await + .with_context(|| { + format!( + "timed out creating legacy proxy client after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to create legacy proxy client")?; + + println!("[relay-hybrid] stage: starting legacy proxy transport"); + let mut rx = timeout(STARTUP_TIMEOUT, proxy.start()) + .await + .with_context(|| { + format!( + "timed out starting legacy proxy transport after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to start legacy proxy")?; + println!("[relay-hybrid] stage: legacy proxy started"); + + let init_id = serde_json::json!(1); + let init_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: init_id.clone(), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": { + "tools": {}, + "resources": {} + }, + "clientInfo": { + "name": "legacy-hybrid-client", + "version": "0.1.0" + } + })), + }); + + let init_response = + send_legacy_request_and_wait(&proxy, &mut rx, init_request, &init_id).await?; + assert_initialize_shape(&init_response)?; + + proxy + .send(&JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + })) + .await + .context("failed to send initialized notification")?; + + let tools_id = serde_json::json!(2); + let tools_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: tools_id.clone(), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + let tools_response = + send_legacy_request_and_wait(&proxy, &mut rx, tools_request, &tools_id).await?; + let tools = extract_tools_list(&tools_response)?; + assert!( + tools + .iter() + .any(|t| t.get("name") == Some(&serde_json::json!("echo"))), + "expected echo tool in hybrid case" + ); + + let call_id = serde_json::json!(3); + let call_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: call_id.clone(), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "echo", + "arguments": { "message": "legacy-client-hello" } + })), + }); + + let call_response = send_legacy_request_and_wait(&proxy, &mut rx, call_request, &call_id) + .await + .context("tools/call failed in hybrid case")?; + let echo_text = extract_first_content_text(&call_response)?; + assert!( + echo_text.contains("legacy-client-hello"), + "unexpected echo output in hybrid case: {echo_text}" + ); + + let unknown_id = serde_json::json!(4); + let unknown_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: unknown_id.clone(), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "no_such_tool", + "arguments": {} + })), + }); + + let unknown_response = + send_legacy_request_and_wait(&proxy, &mut rx, unknown_request, &unknown_id).await?; + assert_error_response(&unknown_response)?; + + proxy.stop().await.context("failed to stop legacy proxy")?; + + Ok(()) + } + .await; + + server_task.abort(); + + if server_task.is_finished() { + let _ = server_task.await; + } + + outcome?; + + println!("[relay-hybrid] pass"); + Ok(()) +} + +async fn run_relay_rmcp_case(relay_url: &str) -> Result<()> { + println!("[relay-rmcp] start (rmcp server + rmcp client)"); + + let server_keys = signer::generate(); + let server_pubkey_hex = server_keys.public_key().to_hex(); + + println!("[relay-rmcp] stage: spawning rmcp server task"); + let relay_url_owned = relay_url.to_string(); + let server_task = tokio::spawn(async move { + let server = NostrMCPGateway::serve_handler( + server_keys, + server_config(&relay_url_owned), + DemoServer::new(), + ) + .await + .with_context(|| format!("failed to start rmcp server on relay {relay_url_owned}"))?; + + let _ = server + .waiting() + .await + .map_err(|e| anyhow!("rmcp server exited with error: {e}"))?; + + Err(anyhow!("rmcp server stopped unexpectedly")) + }); + + sleep(RELAY_WARMUP).await; + + if server_task.is_finished() { + let res = server_task + .await + .map_err(|e| anyhow!("rmcp server task join error: {e}"))?; + return res.context("rmcp server task ended before rmcp client startup"); + } + + let outcome: Result<()> = async { + println!("[relay-rmcp] stage: starting rmcp relay client worker"); + + let client = timeout( + STARTUP_TIMEOUT, + NostrMCPProxy::serve_client_handler( + signer::generate(), + client_config(relay_url, server_pubkey_hex), + RelayRmcpClient, + ), + ) + .await + .with_context(|| { + format!( + "timed out starting rmcp relay client worker after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to start rmcp relay client")?; + println!("[relay-rmcp] stage: rmcp relay client started"); + + let peer = client + .peer_info() + .ok_or_else(|| anyhow!("rmcp relay client did not receive peer info"))?; + let negotiated = peer.protocol_version.to_string(); + assert!( + is_supported_protocol(&negotiated), + "unexpected negotiated protocol version: {negotiated}" + ); + + let tools = client.list_all_tools().await?; + assert!( + tools.iter().any(|t| t.name == "echo"), + "expected echo tool in rmcp relay case" + ); + + let echo = client + .call_tool(call_params( + "echo", + Some(serde_json::json!({ "message": "rmcp-relay-hello" })), + )) + .await?; + let echo_text = first_text(&echo); + assert!( + echo_text.contains("rmcp-relay-hello"), + "unexpected rmcp relay echo output: {echo_text}" + ); + + let resources = client.list_all_resources().await?; + assert!( + resources.iter().any(|r| r.uri.as_str() == "demo://readme"), + "expected demo://readme resource in rmcp relay case" + ); + + match client.call_tool(call_params("no_such_tool", None)).await { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => bail!("expected unknown tool to fail in rmcp relay case"), + } + + client + .cancel() + .await + .context("failed to cancel rmcp relay client")?; + + Ok(()) + } + .await; + + server_task.abort(); + + if server_task.is_finished() { + let _ = server_task.await; + } + + outcome?; + + println!("[relay-rmcp] pass"); + Ok(()) +} + +fn server_config(relay_url: &str) -> GatewayConfig { + let nostr_config = NostrServerTransportConfig::default() + .with_relay_urls(vec![relay_url.to_string()]) + .with_encryption_mode(EncryptionMode::Optional) + .with_server_info( + CtxServerInfo::default() + .with_name("rmcp-matrix-server") + .with_about("rmcp matrix coverage server"), + ) + .with_announced_server(false); + GatewayConfig::new(nostr_config) +} + +fn client_config(relay_url: &str, server_pubkey: String) -> ProxyConfig { + let nostr_config = NostrClientTransportConfig::default() + .with_relay_urls(vec![relay_url.to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Optional); + ProxyConfig::new(nostr_config) +} + +async fn send_legacy_request_and_wait( + proxy: &NostrMCPProxy, + rx: &mut tokio::sync::mpsc::UnboundedReceiver, + request: JsonRpcMessage, + expected_id: &serde_json::Value, +) -> Result { + proxy.send(&request).await?; + + loop { + let maybe_msg = timeout(IO_TIMEOUT, rx.recv()) + .await + .context("timed out waiting for legacy response")?; + + let msg = maybe_msg.ok_or_else(|| anyhow!("legacy response channel closed"))?; + + if msg.id() == Some(expected_id) { + return Ok(msg); + } + + if msg.is_notification() { + continue; + } + } +} + +fn extract_tools_list(response: &JsonRpcMessage) -> Result<&Vec> { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected tools/list response, got {response:?}"); + }; + + resp.result + .get("tools") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow!("tools/list response missing tools array")) +} + +fn extract_first_content_text(response: &JsonRpcMessage) -> Result { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected tools/call response, got {response:?}"); + }; + + let text = resp + .result + .get("content") + .and_then(|v| v.as_array()) + .and_then(|items| items.first()) + .and_then(|item| item.get("text")) + .and_then(|text| text.as_str()) + .ok_or_else(|| anyhow!("tools/call response missing content[0].text"))?; + + Ok(text.to_string()) +} + +fn assert_initialize_shape(response: &JsonRpcMessage) -> Result<()> { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected initialize response, got {response:?}"); + }; + let expected_protocol = mcp_protocol_version(); + let protocol = resp + .result + .get("protocolVersion") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow!("initialize response missing protocolVersion"))?; + + if !is_supported_protocol(protocol) { + bail!( + "unexpected protocolVersion in initialize response: expected one of [{expected_protocol}, {}], got {protocol}", + ProtocolVersion::LATEST + ); + } + + if resp.result.get("serverInfo").is_none() { + bail!("initialize response missing serverInfo"); + } + + Ok(()) +} + +fn is_supported_protocol(protocol: &str) -> bool { + protocol == mcp_protocol_version() || protocol == ProtocolVersion::LATEST.to_string() +} + +fn assert_error_response(response: &JsonRpcMessage) -> Result<()> { + match response { + JsonRpcMessage::ErrorResponse(err) => { + if err.error.code >= 0 { + bail!( + "expected negative JSON-RPC error code, got {}", + err.error.code + ); + } + Ok(()) + } + JsonRpcMessage::Response(resp) => { + if resp.result.get("isError") == Some(&serde_json::json!(true)) { + Ok(()) + } else { + bail!("expected error response but received success result") + } + } + _ => bail!("expected error response, got {response:?}"), + } +} + +fn call_params(name: &'static str, args: Option) -> CallToolRequestParams { + CallToolRequestParams { + name: name.into(), + arguments: args.and_then(|v| serde_json::from_value(v).ok()), + meta: None, + task: None, + } +} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|content| { + if let RawContent::Text(t) = &content.raw { + Some(t.text.clone()) + } else { + None + } + }) + .unwrap_or_default() +} diff --git a/sdk b/sdk new file mode 160000 index 0000000..7a0c5c3 --- /dev/null +++ b/sdk @@ -0,0 +1 @@ +Subproject commit 7a0c5c398c8ffcc6b07ddb181c67a29b77dc3cb4 diff --git a/src/core/constants.rs b/src/core/constants.rs index 3748a3b..dd8cc08 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -9,6 +9,15 @@ pub const CTXVM_MESSAGES_KIND: u16 = 25910; /// Encrypted messages using NIP-59 Gift Wrap (kind 1059) pub const GIFT_WRAP_KIND: u16 = 1059; +/// Ephemeral variant of NIP-59 Gift Wrap (kind 21059, CEP-19) +/// +/// Same structure and semantics as kind 1059, but in NIP-01's ephemeral range. +/// Relays are not expected to store ephemeral events beyond transient forwarding. +pub const EPHEMERAL_GIFT_WRAP_KIND: u16 = 21059; + +/// Replaceable relay list metadata event following NIP-65 (CEP-17) +pub const RELAY_LIST_METADATA_KIND: u16 = 10002; + /// Server announcement (addressable, kind 11316) pub const SERVER_ANNOUNCEMENT_KIND: u16 = 11316; @@ -29,6 +38,9 @@ pub mod tags { /// Public key tag pub const PUBKEY: &str = "p"; + /// Relay URL tag (CEP-17) + pub const RELAY: &str = "r"; + /// Event ID tag for correlation pub const EVENT_ID: &str = "e"; @@ -49,11 +61,42 @@ pub mod tags { /// Support encryption tag pub const SUPPORT_ENCRYPTION: &str = "support_encryption"; + + /// Support ephemeral gift wrap kind (21059) for encrypted messages (CEP-19) + pub const SUPPORT_ENCRYPTION_EPHEMERAL: &str = "support_encryption_ephemeral"; + + /// Support CEP-22 oversized payload transfer via notifications/progress framing + pub const SUPPORT_OVERSIZED_TRANSFER: &str = "support_oversized_transfer"; } /// Maximum message size (1MB) pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Default LRU cache size for deduplication +pub const DEFAULT_LRU_SIZE: usize = 5000; + +/// Default timeout for network/relay operations (30 seconds) +pub const DEFAULT_TIMEOUT_MS: u64 = 30_000; + +/// Default relay targets for discoverability publication (CEP-17). +/// +/// These are used as additional publication targets for server metadata, +/// even when they are not part of the server's operational relay list. +pub const DEFAULT_BOOTSTRAP_RELAY_URLS: &[&str] = &[ + "wss://relay.damus.io", + "wss://relay.primal.net", + "wss://nos.lol", + "wss://relay.snort.social/", + "wss://nostr.mom/", + "wss://nostr.oxtr.dev/", +]; + +/// MCP protocol method for the initialization request +pub const INITIALIZE_METHOD: &str = "initialize"; + +/// MCP protocol method for the initialized notification +pub const NOTIFICATIONS_INITIALIZED_METHOD: &str = "notifications/initialized"; + /// Kinds that should never be encrypted (public announcements) pub const UNENCRYPTED_KINDS: &[u16] = &[ SERVER_ANNOUNCEMENT_KIND, @@ -62,3 +105,122 @@ pub const UNENCRYPTED_KINDS: &[u16] = &[ RESOURCETEMPLATES_LIST_KIND, PROMPTS_LIST_KIND, ]; + +#[cfg(feature = "rmcp")] +pub fn mcp_protocol_version() -> &'static str { + use std::sync::OnceLock; + static VERSION: OnceLock = OnceLock::new(); + VERSION + .get_or_init(|| rmcp::model::ProtocolVersion::LATEST.to_string()) + .as_str() +} + +#[cfg(not(feature = "rmcp"))] +pub const fn mcp_protocol_version() -> &'static str { + "2025-11-25" +} + +// Compile-time range checks (NIP-01 kind ranges). +// Placed at module level so violations are caught in every build, not just `cargo test`. +const _: () = { + // Ephemeral events: 20000 <= kind < 30000 + assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); + assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); + assert!(CTXVM_MESSAGES_KIND >= 20000); + assert!(CTXVM_MESSAGES_KIND < 30000); + // Replaceable events: 10000 <= kind < 20000 + assert!(RELAY_LIST_METADATA_KIND >= 10000); + assert!(RELAY_LIST_METADATA_KIND < 20000); +}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_event_kind_values_match_spec() { + assert_eq!(CTXVM_MESSAGES_KIND, 25910); + assert_eq!(GIFT_WRAP_KIND, 1059); + assert_eq!(EPHEMERAL_GIFT_WRAP_KIND, 21059); + assert_eq!(RELAY_LIST_METADATA_KIND, 10002); + assert_eq!(SERVER_ANNOUNCEMENT_KIND, 11316); + assert_eq!(TOOLS_LIST_KIND, 11317); + assert_eq!(RESOURCES_LIST_KIND, 11318); + assert_eq!(RESOURCETEMPLATES_LIST_KIND, 11319); + assert_eq!(PROMPTS_LIST_KIND, 11320); + } + + #[test] + fn test_tag_values_match_ts_sdk() { + assert_eq!(tags::PUBKEY, "p"); + assert_eq!(tags::RELAY, "r"); + assert_eq!(tags::EVENT_ID, "e"); + assert_eq!(tags::CAPABILITY, "cap"); + assert_eq!(tags::NAME, "name"); + assert_eq!(tags::WEBSITE, "website"); + assert_eq!(tags::PICTURE, "picture"); + assert_eq!(tags::ABOUT, "about"); + assert_eq!(tags::SUPPORT_ENCRYPTION, "support_encryption"); + assert_eq!( + tags::SUPPORT_ENCRYPTION_EPHEMERAL, + "support_encryption_ephemeral" + ); + assert_eq!( + tags::SUPPORT_OVERSIZED_TRANSFER, + "support_oversized_transfer" + ); + } + + #[test] + fn test_announcement_kinds_in_addressable_range() { + // NIP-01: addressable events are 30000 <= kind < 40000 + // However, the spec uses 11316-11320 which are in the replaceable range. + // These are parameterized replaceable events per the ContextVM spec. + for &kind in UNENCRYPTED_KINDS { + assert!(kind >= 11316); + assert!(kind <= 11320); + } + } + + #[test] + fn test_bootstrap_relays_are_wss() { + for url in DEFAULT_BOOTSTRAP_RELAY_URLS { + assert!( + url.starts_with("wss://"), + "Bootstrap relay must use wss: {url}" + ); + } + } + + #[test] + fn test_bootstrap_relays_nonempty() { + assert!( + !DEFAULT_BOOTSTRAP_RELAY_URLS.is_empty(), + "Must have at least one bootstrap relay" + ); + } + + #[test] + fn test_mcp_method_constants() { + assert_eq!(INITIALIZE_METHOD, "initialize"); + assert_eq!( + NOTIFICATIONS_INITIALIZED_METHOD, + "notifications/initialized" + ); + } + + #[test] + fn test_unencrypted_kinds_contains_all_announcements() { + assert!(UNENCRYPTED_KINDS.contains(&SERVER_ANNOUNCEMENT_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&TOOLS_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&RESOURCES_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&RESOURCETEMPLATES_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&PROMPTS_LIST_KIND)); + } + + #[test] + fn test_gift_wrap_not_in_unencrypted() { + assert!(!UNENCRYPTED_KINDS.contains(&GIFT_WRAP_KIND)); + assert!(!UNENCRYPTED_KINDS.contains(&EPHEMERAL_GIFT_WRAP_KIND)); + } +} diff --git a/src/core/serializers.rs b/src/core/serializers.rs index 3d641fd..69c519c 100644 --- a/src/core/serializers.rs +++ b/src/core/serializers.rs @@ -50,7 +50,7 @@ pub fn get_tag_value_from_slice(tags: &[Tag], name: &str) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::core::types::{JsonRpcRequest, JsonRpcMessage}; + use crate::core::types::{JsonRpcMessage, JsonRpcRequest}; #[test] fn test_roundtrip() { diff --git a/src/core/types.rs b/src/core/types.rs index cb10773..d3d0064 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -4,16 +4,19 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Instant; +use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; + // ── Encryption mode ───────────────────────────────────────────────── /// Encryption mode for transport communication. /// /// Controls whether MCP messages are sent as plaintext kind 25910 events /// or wrapped in NIP-59 gift wraps (kind 1059) for end-to-end encryption. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum EncryptionMode { /// Encrypt responses only when the incoming request was encrypted (mirror mode). + #[default] Optional, /// Enforce encryption for all messages; reject plaintext. Required, @@ -21,9 +24,36 @@ pub enum EncryptionMode { Disabled, } -impl Default for EncryptionMode { - fn default() -> Self { - Self::Optional +// Gift-wrap mode (CEP-19) + +// Gift-wrap policy for encrypted transport communication (CEP-19). +// Controls whether encrypted messages use persistent gift wraps (kind `1059`), +// ephemeral gift wraps (kind `21059`), or adapt based on peer support. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum GiftWrapMode { + /// Prefer persistent gift wraps until ephemeral support is explicitly chosen or learned. + #[default] + Optional, + /// Force the ephemeral gift-wrap kind (`21059`) for encrypted messages. + Ephemeral, + /// Force the persistent gift-wrap kind (`1059`) for encrypted messages. + Persistent, +} + +impl GiftWrapMode { + /// Returns whether this mode accepts the given encrypted outer event kind. + pub fn allows_kind(self, kind: u16) -> bool { + match self { + Self::Optional => kind == GIFT_WRAP_KIND || kind == EPHEMERAL_GIFT_WRAP_KIND, + Self::Ephemeral => kind == EPHEMERAL_GIFT_WRAP_KIND, + Self::Persistent => kind == GIFT_WRAP_KIND, + } + } + + /// Returns whether this mode supports sending and advertising ephemeral gift wraps. + pub fn supports_ephemeral(self) -> bool { + !matches!(self, Self::Persistent) } } @@ -34,6 +64,7 @@ impl Default for EncryptionMode { /// Published as the content of a replaceable Nostr event so that clients /// can discover the server's identity and metadata. #[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[non_exhaustive] pub struct ServerInfo { /// Human-readable server name. #[serde(skip_serializing_if = "Option::is_none")] @@ -52,6 +83,34 @@ pub struct ServerInfo { pub about: Option, } +impl ServerInfo { + /// Set the server name. + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + /// Set the server version. + pub fn with_version(mut self, version: impl Into) -> Self { + self.version = Some(version.into()); + self + } + /// Set the server picture URL. + pub fn with_picture(mut self, picture: impl Into) -> Self { + self.picture = Some(picture.into()); + self + } + /// Set the server website URL. + pub fn with_website(mut self, website: impl Into) -> Self { + self.website = Some(website.into()); + self + } + /// Set the server description. + pub fn with_about(mut self, about: impl Into) -> Self { + self.about = Some(about.into()); + self + } +} + // ── Client session ────────────────────────────────────────────────── /// Client session state tracked by the server transport. @@ -61,6 +120,16 @@ pub struct ClientSession { pub is_initialized: bool, /// Whether the client's messages were encrypted. pub is_encrypted: bool, + /// Whether server discovery tags have been sent to this client (one-shot flag). + pub has_sent_common_tags: bool, + /// Whether the client has demonstrated support for ephemeral gift wraps (CEP-19). + pub supports_ephemeral_gift_wrap: bool, + /// Learned from client discovery tags: peer supports NIP-44 encryption. + pub supports_encryption: bool, + /// Learned from client discovery tags: peer supports ephemeral gift wraps (CEP-19). + pub supports_ephemeral_encryption: bool, + /// Learned from client discovery tags: peer supports CEP-22 oversized transfer. + pub supports_oversized_transfer: bool, /// Last activity timestamp. pub last_activity: Instant, /// Pending requests: event_id → original request ID. @@ -75,6 +144,11 @@ impl ClientSession { Self { is_initialized: false, is_encrypted, + has_sent_common_tags: false, + supports_ephemeral_gift_wrap: false, + supports_encryption: false, + supports_ephemeral_encryption: false, + supports_oversized_transfer: false, last_activity: Instant::now(), pending_requests: HashMap::new(), event_to_progress_token: HashMap::new(), @@ -227,3 +301,323 @@ pub struct CapabilityExclusion { /// Optional capability name for method-specific exclusions (e.g., "get_weather"). pub name: Option, } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; + use serde_json::json; + use std::thread; + use std::time::Duration; + + #[test] + fn test_encryption_mode_serde_roundtrip_optional() { + let mode = EncryptionMode::Optional; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"optional\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_encryption_mode_serde_roundtrip_required() { + let mode = EncryptionMode::Required; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"required\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_encryption_mode_serde_roundtrip_disabled() { + let mode = EncryptionMode::Disabled; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"disabled\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_serde_roundtrip_optional() { + let mode = GiftWrapMode::Optional; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"optional\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_serde_roundtrip_ephemeral() { + let mode = GiftWrapMode::Ephemeral; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"ephemeral\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_serde_roundtrip_persistent() { + let mode = GiftWrapMode::Persistent; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"persistent\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_policy_helpers() { + // Optional accepts both kinds + assert!(GiftWrapMode::Optional.allows_kind(GIFT_WRAP_KIND)); + assert!(GiftWrapMode::Optional.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + // Ephemeral only accepts 21059 + assert!(GiftWrapMode::Ephemeral.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + assert!(!GiftWrapMode::Ephemeral.allows_kind(GIFT_WRAP_KIND)); + // Persistent only accepts 1059 + assert!(GiftWrapMode::Persistent.allows_kind(GIFT_WRAP_KIND)); + assert!(!GiftWrapMode::Persistent.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + // supports_ephemeral check + assert!(GiftWrapMode::Optional.supports_ephemeral()); + assert!(GiftWrapMode::Ephemeral.supports_ephemeral()); + assert!(!GiftWrapMode::Persistent.supports_ephemeral()); + } + + fn assert_json_rpc_roundtrip(msg: &JsonRpcMessage) { + let wire = serde_json::to_string(msg).unwrap(); + let parsed: JsonRpcMessage = serde_json::from_str(&wire).unwrap(); + let before = serde_json::to_value(msg).unwrap(); + let after = serde_json::to_value(&parsed).unwrap(); + assert_eq!(before, after); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_request() { + let msg = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(42), + method: "tools/list".to_string(), + params: Some(json!({ "cursor": null })), + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_request_without_params() { + let msg = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!("req-id"), + method: "ping".to_string(), + params: None, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_response() { + let msg = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + result: json!({ "tools": [] }), + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_error_response() { + let msg = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(99), + error: JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: Some(json!({ "hint": "fix it" })), + }, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_notification() { + let msg = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_type_predicates() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(1), + method: "m".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + error: JsonRpcError { + code: -1, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "n".to_string(), + params: None, + }); + + assert!(req.is_request()); + assert!(res.is_response()); + assert!(err.is_error()); + assert!(notif.is_notification()); + } + + #[test] + fn test_json_rpc_error_data_none_omitted() { + let err = JsonRpcError { + code: -32600, + message: "bad".to_string(), + data: None, + }; + let json_str = serde_json::to_string(&err).unwrap(); + let value: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + let obj = value.as_object().expect("error object"); + assert!( + !obj.contains_key("data"), + "expected data omitted when None, got: {json_str}" + ); + } + + #[test] + fn test_json_rpc_message_method() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(0), + method: "tools/call".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(0), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(0), + error: JsonRpcError { + code: 0, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: None, + }); + + assert_eq!(req.method(), Some("tools/call")); + assert_eq!(res.method(), None); + assert_eq!(err.method(), None); + assert_eq!(notif.method(), Some("notifications/progress")); + } + + #[test] + fn test_json_rpc_message_id() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!("abc"), + method: "m".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(7), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!([1, 2]), + error: JsonRpcError { + code: 0, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "n".to_string(), + params: None, + }); + + assert_eq!(req.id(), Some(&json!("abc"))); + assert_eq!(res.id(), Some(&json!(7))); + assert_eq!(err.id(), Some(&json!([1, 2]))); + assert_eq!(notif.id(), None); + } + + #[test] + fn test_server_info_serde_all_fields_present() { + let info = ServerInfo { + name: Some("Test Server".to_string()), + version: Some("1.0.0".to_string()), + picture: Some("https://example.com/p.png".to_string()), + website: Some("https://example.com".to_string()), + about: Some("About text".to_string()), + }; + let json_str = serde_json::to_string(&info).unwrap(); + let parsed: ServerInfo = serde_json::from_str(&json_str).unwrap(); + assert_eq!(parsed.name, info.name); + assert_eq!(parsed.version, info.version); + assert_eq!(parsed.picture, info.picture); + assert_eq!(parsed.website, info.website); + assert_eq!(parsed.about, info.about); + } + + #[test] + fn test_server_info_serde_optional_fields_omitted() { + let info = ServerInfo { + name: None, + version: None, + picture: None, + website: None, + about: None, + }; + let json_str = serde_json::to_string(&info).unwrap(); + assert_eq!(json_str, "{}"); + } + + #[test] + fn test_client_session_new_initial_state_encrypted() { + let session = ClientSession::new(true); + assert!(!session.is_initialized); + assert!(session.is_encrypted); + assert!(session.pending_requests.is_empty()); + assert!(session.event_to_progress_token.is_empty()); + } + + #[test] + fn test_client_session_new_initial_state_plaintext() { + let session = ClientSession::new(false); + assert!(!session.is_initialized); + assert!(!session.is_encrypted); + assert!(session.pending_requests.is_empty()); + assert!(session.event_to_progress_token.is_empty()); + } + + #[test] + fn test_client_session_update_activity() { + let mut session = ClientSession::new(false); + let first = session.last_activity; + thread::sleep(Duration::from_millis(10)); + session.update_activity(); + assert!(session.last_activity > first); + } +} diff --git a/src/core/validation.rs b/src/core/validation.rs index d409b7b..9e6dc5d 100644 --- a/src/core/validation.rs +++ b/src/core/validation.rs @@ -8,6 +8,17 @@ pub fn validate_message_size(content: &str) -> bool { content.len() <= MAX_MESSAGE_SIZE } +/// Validate size and structure, then parse into a [`JsonRpcMessage`]. +pub fn validate_and_parse(content: &str) -> Option { + if !validate_message_size(content) { + tracing::warn!("Message size validation failed: {} bytes", content.len()); + return None; + } + + let value: serde_json::Value = serde_json::from_str(content).ok()?; + validate_message(&value) +} + /// Validate that a JSON value is a well-formed JSON-RPC 2.0 message. /// /// Checks: @@ -61,4 +72,30 @@ mod tests { let big = "x".repeat(MAX_MESSAGE_SIZE + 1); assert!(!validate_message_size(&big)); } + + #[test] + fn test_validate_and_parse_valid_request() { + let content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let msg = validate_and_parse(content).unwrap(); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/list")); + } + + #[test] + fn test_validate_and_parse_rejects_oversized() { + let padding = "x".repeat(MAX_MESSAGE_SIZE); + let content = format!(r#"{{"jsonrpc":"2.0","id":1,"method":"{}"}}"#, padding); + assert!(validate_and_parse(&content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_version() { + let content = r#"{"jsonrpc":"1.0","id":1,"method":"test"}"#; + assert!(validate_and_parse(content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_json() { + assert!(validate_and_parse("not json").is_none()); + } } diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index eb0e876..d5ffa3a 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -64,8 +64,7 @@ pub async fn discover_servers( let mut announcements = Vec::new(); for event in events { - let server_info: ServerInfo = - serde_json::from_str(&event.content).unwrap_or_default(); + let server_info: ServerInfo = serde_json::from_str(&event.content).unwrap_or_default(); announcements.push(ServerAnnouncement { pubkey: event.pubkey.to_hex(), pubkey_parsed: event.pubkey, @@ -120,6 +119,50 @@ pub async fn discover_resource_templates( .await } +/// Discover tools and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_tools_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_tools(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover resources and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_resources_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_resources(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover prompts and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_prompts_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_prompts(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover resource templates and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_resource_templates_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_resource_templates(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + // ── Internal ──────────────────────────────────────────────────────── async fn fetch_list( @@ -153,6 +196,20 @@ async fn fetch_list( .unwrap_or_default()) } +#[cfg(feature = "rmcp")] +fn parse_typed_list(raw: Vec) -> Result> +where + T: serde::de::DeserializeOwned, +{ + let mut parsed = Vec::new(); + for item in raw { + let value = serde_json::from_value(item) + .map_err(|e| Error::Other(format!("Failed to parse typed discovery item: {e}")))?; + parsed.push(value); + } + Ok(parsed) +} + #[cfg(test)] mod tests { use super::*; @@ -175,7 +232,10 @@ mod tests { assert_eq!(parsed.version, Some("1.0.0".to_string())); assert_eq!(parsed.about, Some("A test MCP server".to_string())); assert_eq!(parsed.website, Some("https://example.com".to_string())); - assert_eq!(parsed.picture, Some("https://example.com/pic.png".to_string())); + assert_eq!( + parsed.picture, + Some("https://example.com/pic.png".to_string()) + ); } #[test] @@ -222,7 +282,8 @@ mod tests { }, event_id: EventId::from_hex( "0000000000000000000000000000000000000000000000000000000000000001", - ).unwrap(), + ) + .unwrap(), created_at: Timestamp::now(), }; diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index dadf331..6a6d4bc 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -3,7 +3,7 @@ //! Provides NIP-44 encryption/decryption and NIP-59 gift wrapping. //! The actual gift wrapping is done via nostr-sdk's Client for full NIP-59 compliance. -use crate::core::constants::GIFT_WRAP_KIND; +use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; use crate::core::error::{Error, Result}; use nostr_sdk::prelude::*; @@ -37,16 +37,9 @@ where .map_err(|e| Error::Decryption(e.to_string())) } -/// Decrypt a single-layer NIP-44 gift wrap (kind 1059). -/// -/// This matches the ContextVM JS/TS SDK's encryption scheme: -/// - The gift wrap event has NIP-44 encrypted content (single layer) -/// - Decrypt using recipient's key + event's pubkey (ephemeral sender) -/// - Returns the decrypted plaintext content string -pub async fn decrypt_gift_wrap_single_layer( - signer: &T, - event: &Event, -) -> Result +// Decrypt a single-layer NIP-44 gift wrap (kind 1059). + +pub async fn decrypt_gift_wrap_single_layer(signer: &T, event: &Event) -> Result where T: NostrSigner, { @@ -54,13 +47,8 @@ where decrypt_nip44(signer, &sender_pubkey, &event.content).await } -/// Create a single-layer NIP-44 gift wrap (kind 1059). -/// -/// Matches the ContextVM JS/TS SDK's `encryptMessage`: -/// 1. Generate ephemeral keypair -/// 2. NIP-44 encrypt plaintext using ephemeral_secret + recipient_pubkey -/// 3. Build kind 1059 event with `p` tag pointing to recipient -/// 4. Sign with ephemeral key +// Create a single-layer NIP-44 gift wrap (kind 1059). + pub async fn gift_wrap_single_layer( _signer: &T, recipient: &PublicKey, @@ -73,8 +61,39 @@ where let encrypted = encrypt_nip44(&ephemeral, recipient, plaintext).await?; - let builder = EventBuilder::new(Kind::Custom(GIFT_WRAP_KIND), encrypted) - .tag(Tag::public_key(*recipient)); + let builder = + EventBuilder::new(Kind::Custom(GIFT_WRAP_KIND), encrypted).tag(Tag::public_key(*recipient)); + + builder + .sign_with_keys(&ephemeral) + .map_err(|e| Error::Encryption(e.to_string())) +} + +/// Create a single-layer NIP-44 gift wrap using the provided outer event kind. +/// +/// Only ContextVM's supported persistent (`1059`) and ephemeral (`21059`) gift-wrap +/// kinds are accepted here. +pub async fn gift_wrap_single_layer_with_kind( + _signer: &T, + recipient: &PublicKey, + plaintext: &str, + gift_wrap_kind: u16, +) -> Result +where + T: NostrSigner, +{ + if gift_wrap_kind != GIFT_WRAP_KIND && gift_wrap_kind != EPHEMERAL_GIFT_WRAP_KIND { + return Err(Error::Encryption(format!( + "Unsupported gift-wrap kind for single-layer encryption: {gift_wrap_kind}" + ))); + } + + let ephemeral = Keys::generate(); + + let encrypted = encrypt_nip44(&ephemeral, recipient, plaintext).await?; + + let builder = + EventBuilder::new(Kind::Custom(gift_wrap_kind), encrypted).tag(Tag::public_key(*recipient)); builder .sign_with_keys(&ephemeral) @@ -114,6 +133,8 @@ pub async fn gift_wrap( #[cfg(test)] mod tests { + use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; + use super::*; #[tokio::test] @@ -142,10 +163,7 @@ mod tests { /// 2. NIP-44 encrypt the plaintext using ephemeral_secret + recipient_pubkey /// 3. Build kind 1059 event with encrypted content, `p` tag = recipient /// 4. Sign with ephemeral key - async fn create_js_style_gift_wrap( - plaintext: &str, - recipient: &PublicKey, - ) -> (Event, Keys) { + async fn create_simple_gift_wrap(plaintext: &str, recipient: &PublicKey) -> (Event, Keys) { let ephemeral = Keys::generate(); // Single-layer NIP-44 encrypt @@ -154,7 +172,7 @@ mod tests { .unwrap(); // Build kind 1059 event - let builder = EventBuilder::new(Kind::Custom(1059), encrypted) + let builder = EventBuilder::new(Kind::from(GIFT_WRAP_KIND), encrypted) .tag(Tag::public_key(*recipient)); let event = builder.sign_with_keys(&ephemeral).unwrap(); @@ -183,7 +201,7 @@ mod tests { // Step 3: Encrypt as a gift wrap let (gift_wrap, _ephemeral) = - create_js_style_gift_wrap(&inner_json, &server_keys.public_key()).await; + create_simple_gift_wrap(&inner_json, &server_keys.public_key()).await; assert_eq!(gift_wrap.kind, Kind::Custom(1059)); @@ -265,4 +283,93 @@ mod tests { // (it uses an ephemeral key, like the JS SDK) assert_ne!(gift_wrap_event.pubkey, sender_keys.public_key()); } + + /// Regression: gift-wrapped inner events with a tampered pubkey must be + /// caught by `Event::verify()`. + #[tokio::test] + async fn test_forged_inner_event_detected_by_verify() { + let real_sender = Keys::generate(); + let impersonated = Keys::generate(); + let recipient = Keys::generate(); + + let mcp_content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + + // Step 1: build a legitimately signed inner event + let inner_event = EventBuilder::new(Kind::Custom(25910), mcp_content) + .tag(Tag::public_key(recipient.public_key())) + .sign_with_keys(&real_sender) + .unwrap(); + + // Step 2: tamper the pubkey (keep original, now-invalid, signature) + let mut forged_json: serde_json::Value = serde_json::to_value(&inner_event).unwrap(); + forged_json["pubkey"] = serde_json::Value::String(impersonated.public_key().to_hex()); + let forged_str = serde_json::to_string(&forged_json).unwrap(); + + // Step 3: gift-wrap the forged payload + let (gift_wrap, _) = create_simple_gift_wrap(&forged_str, &recipient.public_key()).await; + + // Decrypt + parse both succeed — the forgery is syntactically valid + let decrypted = decrypt_gift_wrap_single_layer(&recipient, &gift_wrap) + .await + .unwrap(); + let parsed: Event = serde_json::from_str(&decrypted).unwrap(); + assert_eq!(parsed.pubkey, impersonated.public_key()); + + // Signature verification catches the tampered pubkey + assert!( + parsed.verify().is_err(), + "forged inner event must fail signature verification" + ); + } + + #[tokio::test] + async fn test_ephemeral_gift_wrap_roundtrip_single_layer() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + + let mcp_content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let inner_event = EventBuilder::new(Kind::Custom(25910), mcp_content) + .tag(Tag::public_key(recipient_keys.public_key())) + .sign_with_keys(&sender_keys) + .unwrap(); + let inner_json = serde_json::to_string(&inner_event).unwrap(); + + let gift_wrap_event = gift_wrap_single_layer_with_kind( + &sender_keys, + &recipient_keys.public_key(), + &inner_json, + EPHEMERAL_GIFT_WRAP_KIND, + ) + .await + .unwrap(); + + assert_eq!(gift_wrap_event.kind, Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)); + + let decrypted = decrypt_gift_wrap_single_layer(&recipient_keys, &gift_wrap_event) + .await + .unwrap(); + let parsed: Event = serde_json::from_str(&decrypted).unwrap(); + assert_eq!(parsed.pubkey, sender_keys.public_key()); + assert_eq!(parsed.content, mcp_content); + } + + #[tokio::test] + async fn test_invalid_gift_wrap_kind_rejected() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + + let error = gift_wrap_single_layer_with_kind( + &sender_keys, + &recipient_keys.public_key(), + "test", + 4242, + ) + .await + .unwrap_err(); + + assert!( + error.to_string().contains("Unsupported gift-wrap kind"), + "unexpected error: {error}" + ); + } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index d907611..e4bba91 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -8,11 +8,19 @@ use crate::core::types::JsonRpcMessage; use crate::transport::server::{IncomingRequest, NostrServerTransport, NostrServerTransportConfig}; /// Configuration for the gateway. +#[non_exhaustive] pub struct GatewayConfig { /// Nostr server transport configuration. pub nostr_config: NostrServerTransportConfig, } +impl GatewayConfig { + /// Create a new gateway configuration. + pub fn new(nostr_config: NostrServerTransportConfig) -> Self { + Self { nostr_config } + } +} + /// Gateway that bridges a local MCP server to Nostr. /// /// The gateway listens for incoming MCP requests via Nostr, forwards them @@ -40,9 +48,7 @@ impl NostrMCPGateway { /// /// The caller is responsible for processing requests and calling /// `send_response` for each one. - pub async fn start( - &mut self, - ) -> Result> { + pub async fn start(&mut self) -> Result> { if self.is_running { return Err(Error::Other("Gateway already running".to_string())); } @@ -81,6 +87,32 @@ impl NostrMCPGateway { } } +#[cfg(feature = "rmcp")] +impl NostrMCPGateway { + /// Start a gateway directly from an rmcp server handler. + /// + /// This additive API keeps the existing `new/start/send_response` flow intact, + /// while also allowing direct `handler.serve(transport)` style usage. + pub async fn serve_handler( + signer: T, + config: GatewayConfig, + handler: H, + ) -> Result> + where + T: nostr_sdk::prelude::IntoNostrSigner, + H: rmcp::ServerHandler, + { + use crate::NostrServerTransport; + use rmcp::ServiceExt; + + let transport = NostrServerTransport::new(signer, config.nostr_config).await?; + handler + .serve(transport) + .await + .map_err(|e| Error::Other(format!("rmcp server initialization failed: {e}"))) + } +} + #[cfg(test)] mod tests { use super::*; @@ -93,25 +125,44 @@ mod tests { let nostr_config = NostrServerTransportConfig { relay_urls: vec!["wss://relay.example.com".to_string()], encryption_mode: EncryptionMode::Required, + gift_wrap_mode: GiftWrapMode::Optional, server_info: Some(ServerInfo { name: Some("Test Gateway".to_string()), version: Some("1.0.0".to_string()), ..Default::default() }), - is_public_server: true, + is_announced_server: true, allowed_public_keys: vec!["abc123".to_string()], excluded_capabilities: vec![], + max_sessions: 1000, cleanup_interval: Duration::from_secs(120), session_timeout: Duration::from_secs(600), + request_timeout: Duration::from_secs(60), }; let config = GatewayConfig { nostr_config }; - assert_eq!(config.nostr_config.relay_urls, vec!["wss://relay.example.com"]); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Required); - assert!(config.nostr_config.is_public_server); + assert_eq!( + config.nostr_config.relay_urls, + vec!["wss://relay.example.com"] + ); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Required + ); + assert!(config.nostr_config.is_announced_server); assert_eq!(config.nostr_config.allowed_public_keys.len(), 1); - assert!(config.nostr_config.server_info.as_ref().unwrap().name.as_ref().unwrap() == "Test Gateway"); + assert!( + config + .nostr_config + .server_info + .as_ref() + .unwrap() + .name + .as_ref() + .unwrap() + == "Test Gateway" + ); } #[test] @@ -119,7 +170,10 @@ mod tests { let config = GatewayConfig { nostr_config: NostrServerTransportConfig::default(), }; - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); - assert!(!config.nostr_config.is_public_server); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Optional + ); + assert!(!config.nostr_config.is_announced_server); } } diff --git a/src/lib.rs b/src/lib.rs index becd92e..7615c7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,13 +45,26 @@ pub mod relay; pub mod signer; pub mod transport; +#[cfg(feature = "rmcp")] +pub mod rmcp_transport; // Re-export commonly used types pub use core::error::{Error, Result}; pub use core::types::{ - CapabilityExclusion, ClientSession, EncryptionMode, JsonRpcError, JsonRpcErrorResponse, - JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ServerInfo, + CapabilityExclusion, ClientSession, EncryptionMode, GiftWrapMode, JsonRpcError, + JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + ServerInfo, }; pub use discovery::ServerAnnouncement; -pub use relay::RelayPool; -pub use transport::client::{NostrClientTransport, NostrClientTransportConfig}; -pub use transport::server::{IncomingRequest, NostrServerTransport, NostrServerTransportConfig}; +pub use relay::mock::MockRelayPool; +pub use relay::{RelayPool, RelayPoolTrait}; +pub use transport::client::{ + ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig, +}; +pub use transport::discovery_tags::{DiscoveredPeerCapabilities, PeerCapabilities}; +pub use transport::server::{ + IncomingRequest, NostrServerTransport, NostrServerTransportConfig, RouteEntry, + ServerEventRouteStore, SessionSnapshot, SessionStore, +}; + +#[cfg(feature = "rmcp")] +pub use rmcp; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index fab56d7..4833127 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -8,11 +8,19 @@ use crate::core::types::JsonRpcMessage; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; /// Configuration for the proxy. +#[non_exhaustive] pub struct ProxyConfig { /// Nostr client transport configuration. pub nostr_config: NostrClientTransportConfig, } +impl ProxyConfig { + /// Create a new proxy configuration. + pub fn new(nostr_config: NostrClientTransportConfig) -> Self { + Self { nostr_config } + } +} + /// Proxy that connects to a remote MCP server via Nostr. pub struct NostrMCPProxy { transport: NostrClientTransport, @@ -34,9 +42,7 @@ impl NostrMCPProxy { } /// Start the proxy. Returns a receiver for incoming responses/notifications. - pub async fn start( - &mut self, - ) -> Result> { + pub async fn start(&mut self) -> Result> { if self.is_running { return Err(Error::Other("Proxy already running".to_string())); } @@ -70,6 +76,32 @@ impl NostrMCPProxy { } } +#[cfg(feature = "rmcp")] +impl NostrMCPProxy { + /// Start a proxy directly from an rmcp client handler. + /// + /// This additive API keeps the existing `new/start/send` flow intact, + /// while also allowing direct `handler.serve(transport)` style usage. + pub async fn serve_client_handler( + signer: T, + config: ProxyConfig, + handler: H, + ) -> Result> + where + T: nostr_sdk::prelude::IntoNostrSigner, + H: rmcp::ClientHandler, + { + use crate::NostrClientTransport; + use rmcp::ServiceExt; + + let transport = NostrClientTransport::new(signer, config.nostr_config).await?; + handler + .serve(transport) + .await + .map_err(|e| Error::Other(format!("rmcp client initialization failed: {e}"))) + } +} + #[cfg(test)] mod tests { use super::*; @@ -86,15 +118,22 @@ mod tests { relay_urls: vec!["wss://relay.example.com".to_string()], server_pubkey: server_pubkey.clone(), encryption_mode: EncryptionMode::Required, + gift_wrap_mode: GiftWrapMode::Optional, is_stateless: true, timeout: Duration::from_secs(60), }; let config = ProxyConfig { nostr_config }; - assert_eq!(config.nostr_config.relay_urls, vec!["wss://relay.example.com"]); + assert_eq!( + config.nostr_config.relay_urls, + vec!["wss://relay.example.com"] + ); assert_eq!(config.nostr_config.server_pubkey, server_pubkey); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Required); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Required + ); assert!(config.nostr_config.is_stateless); assert_eq!(config.nostr_config.timeout, Duration::from_secs(60)); } @@ -105,6 +144,9 @@ mod tests { nostr_config: NostrClientTransportConfig::default(), }; assert!(!config.nostr_config.is_stateless); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Optional + ); } } diff --git a/src/relay/mock.rs b/src/relay/mock.rs new file mode 100644 index 0000000..2b181d6 --- /dev/null +++ b/src/relay/mock.rs @@ -0,0 +1,359 @@ +//! In-memory mock relay pool for network-free testing. +//! +//! Mirrors the design of the TypeScript SDK's `MockRelayHub`: +//! - `publish_event` stores the event and broadcasts it to all `notifications()` receivers. +//! - `subscribe` registers filters and immediately replays matching stored events through the +//! broadcast, so listeners that called `notifications()` before `subscribe()` see the replay. +//! - `connect` / `disconnect` are no-ops — no sockets are opened. +//! - Signing uses a freshly generated ephemeral `Keys`; `signer()` returns it wrapped in `Arc` +//! so encryption code can call it without any real relay connection. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; + +use nostr_sdk::prelude::*; + +use crate::core::error::{Error, Result}; +use crate::relay::RelayPoolTrait; + +// ── Internal state ──────────────────────────────────────────────────────────── + +struct MockRelayInner { + events: Vec, + /// Active subscriptions: id → filters registered by that subscription. + subscriptions: HashMap>, + next_sub_id: u32, +} + +impl MockRelayInner { + fn new() -> Self { + Self { + events: Vec::new(), + subscriptions: HashMap::new(), + next_sub_id: 0, + } + } +} + +// ── Public struct ───────────────────────────────────────────────────────────── + +/// In-memory relay pool for deterministic, network-free testing. +/// +/// Create one with [`MockRelayPool::new`] and pass it (wrapped in `Arc`) wherever +/// an `Arc` is expected. +pub struct MockRelayPool { + inner: Arc>, + /// Broadcast sender — every published event is sent here so that all + /// `notifications()` receivers see it. + notification_tx: tokio::sync::broadcast::Sender, + /// Ephemeral key used for signing in `publish` / `sign` / `signer`. + keys: Keys, +} + +impl MockRelayPool { + /// Create a new mock relay pool with a freshly generated ephemeral signing key. + pub fn new() -> Self { + let keys = Keys::generate(); + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + Self { + inner: Arc::new(Mutex::new(MockRelayInner::new())), + notification_tx: tx, + keys, + } + } + + /// The ephemeral public key used by this mock for signing. + pub fn mock_public_key(&self) -> PublicKey { + self.keys.public_key() + } + + /// The ephemeral signing keys (for manual event injection in tests). + pub fn mock_keys(&self) -> Keys { + self.keys.clone() + } + + /// Like [`new`](Self::new) but with caller-provided signing keys. + pub fn with_keys(keys: Keys) -> Self { + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + Self { + inner: Arc::new(Mutex::new(MockRelayInner::new())), + notification_tx: tx, + keys, + } + } + + /// Create a pair of linked mock relay pools with different signing keys. + /// + /// Both pools share the same event store and notification channel; events + /// published by one are visible to the other's `notifications()` receivers. + pub fn create_pair() -> (Self, Self) { + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + let inner = Arc::new(Mutex::new(MockRelayInner::new())); + let a = Self { + inner: Arc::clone(&inner), + notification_tx: tx.clone(), + keys: Keys::generate(), + }; + let b = Self { + inner, + notification_tx: tx, + keys: Keys::generate(), + }; + (a, b) + } + + /// Create `n` linked mock relay pools with different signing keys. + /// + /// All pools share the same event store and notification channel so events + /// published by any one pool are visible to all others' `notifications()` + /// receivers. Useful for multi-client integration tests. + pub fn create_linked_group(n: usize) -> Vec { + assert!(n > 0, "group must have at least one pool"); + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + let inner = Arc::new(Mutex::new(MockRelayInner::new())); + (0..n) + .map(|_| Self { + inner: Arc::clone(&inner), + notification_tx: tx.clone(), + keys: Keys::generate(), + }) + .collect() + } + + /// Clone of all events published so far (useful for assertions in tests). + pub async fn stored_events(&self) -> Vec { + self.inner.lock().await.events.clone() + } +} + +impl Default for MockRelayPool { + fn default() -> Self { + Self::new() + } +} + +// ── RelayPoolTrait impl ─────────────────────────────────────────────────────── + +#[async_trait] +impl RelayPoolTrait for MockRelayPool { + /// No-op: the mock has no sockets to open. + async fn connect(&self, _relay_urls: &[String]) -> Result<()> { + Ok(()) + } + + /// No-op: the mock has no sockets to close. + async fn disconnect(&self) -> Result<()> { + Ok(()) + } + + /// Store the event and broadcast it to all current `notifications()` receivers. + async fn publish_event(&self, event: &Event) -> Result { + let event_id = event.id; + + { + let mut inner = self.inner.lock().await; + inner.events.push(event.clone()); + } + + // Always broadcast — consumers filter by kind/pubkey/tag themselves, + // which mirrors how nostr-sdk's real notification stream works. + let notification = make_notification(event.clone()); + // Ignore send errors: they just mean there are no active receivers yet. + let _ = self.notification_tx.send(notification); + + Ok(event_id) + } + + /// Sign `builder` with the ephemeral key, then call `publish_event`. + async fn publish(&self, builder: EventBuilder) -> Result { + let event = sign_with_keys(builder, &self.keys)?; + let id = event.id; + self.publish_event(&event).await?; + Ok(id) + } + + /// Sign `builder` with the ephemeral key and return the event without publishing. + async fn sign(&self, builder: EventBuilder) -> Result { + sign_with_keys(builder, &self.keys) + } + + /// Return the ephemeral key as a signer. + async fn signer(&self) -> Result> { + Ok(Arc::new(self.keys.clone()) as Arc) + } + + /// Return a new broadcast receiver. Each call gets an independent receiver + /// that sees all events published *after* this call, plus any replayed by + /// a subsequent `subscribe()`. + fn notifications(&self) -> tokio::sync::broadcast::Receiver { + self.notification_tx.subscribe() + } + + /// Return the ephemeral public key. + async fn public_key(&self) -> Result { + Ok(self.keys.public_key()) + } + + /// Register the filters and immediately replay any already-stored events that + /// match them through the broadcast channel, mirroring the behaviour of a + /// real relay that sends historical events before EOSE. + async fn subscribe(&self, filters: Vec) -> Result<()> { + let replay = { + let mut inner = self.inner.lock().await; + let sub_id = inner.next_sub_id; + inner.next_sub_id += 1; + + // Store filters first so the replay read comes from the stored value, + // ensuring the field is both written and read (no dead-code warning). + inner.subscriptions.insert(sub_id, filters); + + // Clone events so we can release the events borrow before borrowing subscriptions. + let events_snapshot = inner.events.clone(); + let stored = inner.subscriptions.get(&sub_id).expect("just inserted"); + events_snapshot + .into_iter() + .filter(|e| { + stored + .iter() + .any(|f| f.match_event(e, MatchEventOptions::default())) + }) + .collect::>() + }; + + for event in replay { + let _ = self.notification_tx.send(make_notification(event)); + } + + Ok(()) + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn sign_with_keys(builder: EventBuilder, keys: &Keys) -> Result { + builder + .sign_with_keys(keys) + .map_err(|e| Error::Transport(e.to_string())) +} + +fn make_notification(event: Event) -> RelayPoolNotification { + RelayPoolNotification::Event { + relay_url: RelayUrl::parse("wss://mock.relay").expect("hardcoded URL"), + subscription_id: SubscriptionId::generate(), + event: Box::new(event), + } +} + +// ── Unit tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn connect_and_disconnect_are_noops() { + let pool = MockRelayPool::new(); + assert!(pool.connect(&["wss://unused".to_string()]).await.is_ok()); + assert!(pool.disconnect().await.is_ok()); + } + + #[tokio::test] + async fn publish_event_stores_and_broadcasts() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + let keys = Keys::generate(); + let event = EventBuilder::new(Kind::TextNote, "hello") + .sign_with_keys(&keys) + .unwrap(); + + pool.publish_event(&event).await.unwrap(); + + assert_eq!(pool.stored_events().await.len(), 1); + let notif = rx.try_recv().unwrap(); + if let RelayPoolNotification::Event { event: e, .. } = notif { + assert_eq!(e.id, event.id); + } else { + panic!("expected Event notification"); + } + } + + #[tokio::test] + async fn publish_signs_and_stores() { + let pool = MockRelayPool::new(); + let builder = EventBuilder::new(Kind::TextNote, "signed"); + pool.publish(builder).await.unwrap(); + let stored = pool.stored_events().await; + assert_eq!(stored.len(), 1); + assert_eq!(stored[0].pubkey, pool.mock_public_key()); + } + + #[tokio::test] + async fn sign_does_not_publish() { + let pool = MockRelayPool::new(); + let builder = EventBuilder::new(Kind::TextNote, "unsigned"); + let event = pool.sign(builder).await.unwrap(); + assert_eq!(event.pubkey, pool.mock_public_key()); + assert!(pool.stored_events().await.is_empty()); + } + + #[tokio::test] + async fn signer_uses_same_key_as_publish() { + let pool = MockRelayPool::new(); + let signer = pool.signer().await.unwrap(); + let expected_pubkey = pool.mock_public_key(); + assert_eq!(signer.get_public_key().await.unwrap(), expected_pubkey); + } + + #[tokio::test] + async fn subscribe_replays_matching_stored_events() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + // Pre-publish two events + let keys = Keys::generate(); + let e1 = EventBuilder::new(Kind::TextNote, "one") + .sign_with_keys(&keys) + .unwrap(); + let e2 = EventBuilder::new(Kind::Custom(9999), "two") + .sign_with_keys(&keys) + .unwrap(); + pool.publish_event(&e1).await.unwrap(); + pool.publish_event(&e2).await.unwrap(); + + // Drain the two publish notifications + rx.try_recv().unwrap(); + rx.try_recv().unwrap(); + + // Subscribe for TextNote only — e1 should be replayed, e2 not + let filter = Filter::new().kind(Kind::TextNote); + pool.subscribe(vec![filter]).await.unwrap(); + + let replayed = rx.try_recv().unwrap(); + if let RelayPoolNotification::Event { event, .. } = replayed { + assert_eq!(event.id, e1.id); + } else { + panic!("expected replayed Event notification"); + } + // e2 should not be replayed + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn notifications_receives_future_publishes() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + let keys = Keys::generate(); + let event = EventBuilder::new(Kind::TextNote, "future") + .sign_with_keys(&keys) + .unwrap(); + pool.publish_event(&event).await.unwrap(); + + let notif = rx.try_recv().unwrap(); + assert!(matches!(notif, RelayPoolNotification::Event { .. })); + } +} diff --git a/src/relay/mod.rs b/src/relay/mod.rs index dedc94f..198a5a7 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -2,10 +2,38 @@ //! //! Wraps nostr-sdk's Client for relay connection, event publishing, and subscription. +pub mod mock; +pub use mock::MockRelayPool; + +use async_trait::async_trait; + use crate::core::error::{Error, Result}; use nostr_sdk::prelude::*; use std::sync::Arc; +/// Trait abstracting relay pool operations, enabling dependency injection and testing. +#[async_trait] +pub trait RelayPoolTrait: Send + Sync { + /// Connect to the given relay URLs. + async fn connect(&self, relay_urls: &[String]) -> Result<()>; + /// Disconnect from all relays. + async fn disconnect(&self) -> Result<()>; + /// Publish a pre-built event to relays. + async fn publish_event(&self, event: &Event) -> Result; + /// Build, sign, and publish an event from a builder. + async fn publish(&self, builder: EventBuilder) -> Result; + /// Sign an event builder without publishing. + async fn sign(&self, builder: EventBuilder) -> Result; + /// Get the signer associated with this relay pool. + async fn signer(&self) -> Result>; + /// Get notifications receiver for event streaming. + fn notifications(&self) -> tokio::sync::broadcast::Receiver; + /// Get the public key of the signer. + async fn public_key(&self) -> Result; + /// Subscribe to events matching filters. + async fn subscribe(&self, filters: Vec) -> Result<()>; +} + /// Relay pool wrapper for managing Nostr relay connections. pub struct RelayPool { client: Arc, @@ -106,3 +134,45 @@ impl RelayPool { Ok(()) } } + +#[async_trait] +impl RelayPoolTrait for RelayPool { + async fn connect(&self, relay_urls: &[String]) -> Result<()> { + RelayPool::connect(self, relay_urls).await + } + + async fn disconnect(&self) -> Result<()> { + RelayPool::disconnect(self).await + } + + async fn publish_event(&self, event: &Event) -> Result { + RelayPool::publish_event(self, event).await + } + + async fn publish(&self, builder: EventBuilder) -> Result { + RelayPool::publish(self, builder).await + } + + async fn sign(&self, builder: EventBuilder) -> Result { + RelayPool::sign(self, builder).await + } + + async fn signer(&self) -> Result> { + self.client + .signer() + .await + .map_err(|e| Error::Other(e.to_string())) + } + + fn notifications(&self) -> tokio::sync::broadcast::Receiver { + RelayPool::notifications(self) + } + + async fn public_key(&self) -> Result { + RelayPool::public_key(self).await + } + + async fn subscribe(&self, filters: Vec) -> Result<()> { + RelayPool::subscribe(self, filters).await + } +} diff --git a/src/rmcp_transport/convert.rs b/src/rmcp_transport/convert.rs new file mode 100644 index 0000000..7df0783 --- /dev/null +++ b/src/rmcp_transport/convert.rs @@ -0,0 +1,247 @@ +//! Conversion boundary between internal JSON-RPC messages and rmcp message types. +//! +//! These helpers intentionally convert via serde JSON to preserve wire-level +//! compatibility and avoid fragile hand-mapping between evolving type systems. + +use crate::core::types::JsonRpcMessage; + +const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::convert"; + +/// Convert internal JSON-RPC message into rmcp server RX message. +/// +/// Role mapping: +/// - RoleServer RX receives client-originated messages. +pub fn internal_to_rmcp_server_rx( + msg: &JsonRpcMessage, +) -> Option> { + let direction = "internal_to_rmcp_server_rx"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" + ); + None + } + } +} + +/// Convert internal JSON-RPC message into rmcp client RX message. +/// +/// Role mapping: +/// - RoleClient RX receives server-originated messages. +pub fn internal_to_rmcp_client_rx( + msg: &JsonRpcMessage, +) -> Option> { + let direction = "internal_to_rmcp_client_rx"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" + ); + None + } + } +} + +/// Convert rmcp server TX message back into internal JSON-RPC. +pub fn rmcp_server_tx_to_internal( + msg: rmcp::service::TxJsonRpcMessage, +) -> Option { + let direction = "rmcp_server_tx_to_internal"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" + ); + None + } + } +} + +/// Convert rmcp client TX message back into internal JSON-RPC. +pub fn rmcp_client_tx_to_internal( + msg: rmcp::service::TxJsonRpcMessage, +) -> Option { + let direction = "rmcp_client_tx_to_internal"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" + ); + None + } + } +} + +#[cfg(all(test, feature = "rmcp"))] +mod tests { + use super::*; + use crate::core::types::{JsonRpcRequest, JsonRpcResponse}; + + #[test] + fn test_internal_request_to_rmcp_server_rx_ping() { + let internal = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "ping".to_string(), + params: None, + }); + + let rmcp_msg = internal_to_rmcp_server_rx(&internal) + .expect("expected conversion to rmcp server rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + + assert_eq!(value.get("method"), Some(&serde_json::json!("ping"))); + assert_eq!(value.get("id"), Some(&serde_json::json!(1))); + } + + #[test] + fn test_internal_response_to_rmcp_client_rx_empty_result() { + let internal = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + result: serde_json::json!({}), + }); + + let rmcp_msg = internal_to_rmcp_client_rx(&internal) + .expect("expected conversion to rmcp client rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + + assert_eq!(value.get("id"), Some(&serde_json::json!(42))); + assert_eq!(value.get("result"), Some(&serde_json::json!({}))); + } + + #[test] + fn test_rmcp_server_tx_to_internal_response() { + let rmcp_msg = rmcp::model::ServerJsonRpcMessage::response( + rmcp::model::ServerResult::empty(()), + rmcp::model::RequestId::Number(7), + ); + + let internal = rmcp_server_tx_to_internal(rmcp_msg) + .expect("expected conversion from rmcp server tx to internal JSON-RPC"); + + match internal { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!(7)); + assert_eq!(resp.result, serde_json::json!({})); + } + other => panic!("expected internal response, got {other:?}"), + } + } + + #[test] + fn test_rmcp_client_tx_to_internal_response() { + let rmcp_msg = rmcp::model::ClientJsonRpcMessage::response( + rmcp::model::ClientResult::empty(()), + rmcp::model::RequestId::Number(9), + ); + + let internal = rmcp_client_tx_to_internal(rmcp_msg) + .expect("expected conversion from rmcp client tx to internal JSON-RPC"); + + match internal { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!(9)); + assert_eq!(resp.result, serde_json::json!({})); + } + other => panic!("expected internal response, got {other:?}"), + } + } + + #[test] + fn test_server_rx_roundtrip_preserves_wire_shape() { + let internal = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("abc"), + method: "ping".to_string(), + params: None, + }); + + let rmcp_msg = internal_to_rmcp_server_rx(&internal) + .expect("expected conversion to rmcp server rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + let roundtrip_internal: JsonRpcMessage = + serde_json::from_value(value).expect("deserialize back to internal JSON-RPC"); + + match roundtrip_internal { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!("abc")); + assert_eq!(req.method, "ping"); + } + other => panic!("expected internal request, got {other:?}"), + } + } +} diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs new file mode 100644 index 0000000..436f1dd --- /dev/null +++ b/src/rmcp_transport/mod.rs @@ -0,0 +1,17 @@ +//! rmcp integration for ContextVM Nostr transports. +//! +//! This module contains the conversion helpers and worker bridge that let raw +//! ContextVM transports plug directly into rmcp service APIs. + +pub mod convert; +pub mod transport; +pub mod worker; + +#[cfg(test)] +mod pipeline_tests; + +pub use convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, +}; +pub use worker::{NostrClientWorker, NostrServerWorker}; diff --git a/src/rmcp_transport/pipeline_tests.rs b/src/rmcp_transport/pipeline_tests.rs new file mode 100644 index 0000000..ff4844e --- /dev/null +++ b/src/rmcp_transport/pipeline_tests.rs @@ -0,0 +1,497 @@ +//! End-to-end pipeline tests for the rmcp ↔ Nostr transport integration. +//! +//! These tests verify every step of the message journey without requiring a live +//! relay connection: +//! +//! ```text +//! Nostr event content (JSON string) +//! → serializers::nostr_event_to_mcp_message [Layer 1: deserialise] +//! → internal_to_rmcp_server_rx [Layer 2: type bridge] +//! → (rmcp handler processes it) [Layer 3: rmcp dispatch – simulated] +//! → rmcp_server_tx_to_internal [Layer 4: type bridge back] +//! → send_response (event_id correlation) [Layer 5: route back to Nostr – mocked] +//! ``` + +#[cfg(all(test, feature = "rmcp"))] +mod tests { + use std::sync::Arc; + + use rmcp::model::{ + CallToolRequestParams, CallToolResult, ClientJsonRpcMessage, ClientResult, ErrorData, + Implementation, ProtocolVersion, RequestId, ServerCapabilities, ServerInfo, + ServerJsonRpcMessage, ServerResult, + }; + use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + schemars, tool, tool_handler, tool_router, ClientHandler, ServerHandler, ServiceExt, + }; + + use crate::core::serializers; + use crate::core::types::{EncryptionMode, GiftWrapMode}; + use crate::core::types::{ + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + }; + use crate::relay::mock::MockRelayPool; + use crate::relay::RelayPoolTrait; + use crate::rmcp_transport::convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, + }; + use crate::transport::{ + client::{NostrClientTransport, NostrClientTransportConfig}, + server::{NostrServerTransport, NostrServerTransportConfig}, + }; + + #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] + struct EchoParams { + message: String, + } + + #[derive(Clone)] + struct StatelessTestServer { + tool_router: ToolRouter, + } + + impl StatelessTestServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + } + + #[tool_router] + impl StatelessTestServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + format!("Echo: {message}"), + )])) + } + } + + #[tool_handler] + impl ServerHandler for StatelessTestServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "stateless-test-server".to_string(), + title: Some("Stateless Test Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Stateless rmcp regression test server".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Use the echo tool".to_string()), + } + } + } + + #[derive(Clone, Default)] + struct StatelessTestClient; + impl ClientHandler for StatelessTestClient {} + + // ── Layer 1: Nostr event content → JsonRpcMessage ────────────────────── + + #[test] + fn layer1_nostr_content_to_internal_request() { + let content = r#"{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}"#; + let msg = serializers::nostr_event_to_mcp_message(content) + .expect("valid MCP request should parse"); + + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("ping")); + assert_eq!(msg.id(), Some(&serde_json::json!(1))); + } + + #[test] + fn layer1_nostr_content_to_internal_tools_list() { + let content = r#"{"jsonrpc":"2.0","id":"abc","method":"tools/list","params":{}}"#; + let msg = serializers::nostr_event_to_mcp_message(content).unwrap(); + assert_eq!(msg.method(), Some("tools/list")); + assert_eq!(msg.id(), Some(&serde_json::json!("abc"))); + } + + #[test] + fn layer1_nostr_content_to_internal_notification() { + let content = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#; + let msg = serializers::nostr_event_to_mcp_message(content).unwrap(); + assert!(!msg.is_request()); + assert_eq!(msg.method(), Some("notifications/initialized")); + } + + #[test] + fn layer1_nostr_content_invalid_json_returns_none() { + assert!(serializers::nostr_event_to_mcp_message("not json").is_none()); + } + + #[test] + fn layer1_nostr_event_to_mcp_message_no_version_check() { + // DESIGN NOTE: nostr_event_to_mcp_message uses raw serde deserialization — + // it does NOT reject invalid jsonrpc versions. Version enforcement happens + // one layer up in base.rs via validate_message(), which IS tested separately + // in core::validation::tests::test_invalid_version and + // transport::base::tests::test_convert_event_to_mcp_invalid_jsonrpc_version. + // + // A message with jsonrpc "1.0" will parse successfully at the serializer + // layer because JsonRpcRequest accepts any String for the jsonrpc field. + let content = r#"{"jsonrpc":"1.0","id":1,"method":"ping"}"#; + // It parses — the struct captures jsonrpc as a plain String. + let msg = serializers::nostr_event_to_mcp_message(content); + // We don't assert None here; rejection happens in base.rs, not here. + // What we DO assert: if it parsed, the method and id are intact. + if let Some(msg) = msg { + assert_eq!(msg.method(), Some("ping")); + } + // The real rejection path is covered by: + // transport::base::tests::test_convert_event_to_mcp_invalid_jsonrpc_version + } + + // ── Layer 2: JsonRpcMessage → rmcp RxJsonRpcMessage (server) ─────────── + + #[test] + fn layer2_internal_request_converts_to_rmcp_server_rx() { + let msg = make_request("ping", serde_json::json!(1), None); + let rmcp = internal_to_rmcp_server_rx(&msg).expect("ping should convert"); + + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "ping"); + assert_eq!(v["id"], serde_json::json!(1)); + assert_eq!(v["jsonrpc"], "2.0"); + } + + #[test] + fn layer2_string_id_preserved_through_bridge() { + let msg = make_request("tools/list", serde_json::json!("req-xyz"), None); + let rmcp = internal_to_rmcp_server_rx(&msg).unwrap(); + + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["id"], serde_json::json!("req-xyz")); + } + + #[test] + fn layer2_notification_converts_to_rmcp_server_rx() { + let msg = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + let rmcp = + internal_to_rmcp_server_rx(&msg).expect("initialized notification should convert"); + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "notifications/initialized"); + } + + #[test] + fn layer2_tools_list_with_params_converts() { + let msg = make_request( + "tools/list", + serde_json::json!(7), + Some(serde_json::json!({"cursor": "next-page"})), + ); + let rmcp = internal_to_rmcp_server_rx(&msg).unwrap(); + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "tools/list"); + assert_eq!(v["params"]["cursor"], "next-page"); + } + + // ── Layer 3+4: Simulated handler → rmcp response → internal ──────────── + + #[test] + fn layer4_rmcp_ping_response_roundtrip_number_id() { + // Simulate rmcp handler producing a ping response + let rmcp_response = + ServerJsonRpcMessage::response(ServerResult::empty(()), RequestId::Number(42)); + let internal = + rmcp_server_tx_to_internal(rmcp_response).expect("ping response should convert back"); + + match internal { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id, serde_json::json!(42)); + assert_eq!(r.jsonrpc, "2.0"); + } + other => panic!("expected Response, got {other:?}"), + } + } + + #[test] + fn layer4_rmcp_ping_response_roundtrip_string_id() { + let rmcp_response = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from("req-xyz")), + ); + let internal = rmcp_server_tx_to_internal(rmcp_response).unwrap(); + + match internal { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id, serde_json::json!("req-xyz")); + } + other => panic!("expected Response, got {other:?}"), + } + } + + // ── Full roundtrip: internal → rmcp → internal ────────────────────────── + + #[test] + fn full_server_roundtrip_request_id_preserved() { + // Layer 2: convert incoming request to rmcp + let original = make_request("ping", serde_json::json!(99), None); + let rmcp_rx = internal_to_rmcp_server_rx(&original).unwrap(); + + // Extract the ID that rmcp sees + let rmcp_value = serde_json::to_value(&rmcp_rx).unwrap(); + let id_seen_by_rmcp = rmcp_value["id"].clone(); + assert_eq!(id_seen_by_rmcp, serde_json::json!(99)); + + // Layer 4: rmcp produces a response with the same ID echoed back + let rmcp_tx = + ServerJsonRpcMessage::response(ServerResult::empty(()), RequestId::Number(99)); + let response = rmcp_server_tx_to_internal(rmcp_tx).unwrap(); + + // The response ID must equal the original request ID + assert_eq!(response.id(), Some(&serde_json::json!(99))); + } + + #[test] + fn full_client_roundtrip_response_id_preserved() { + // Client side: rmcp produces an outbound request + let rmcp_tx = ClientJsonRpcMessage::response(ClientResult::empty(()), RequestId::Number(7)); + let internal = rmcp_client_tx_to_internal(rmcp_tx).unwrap(); + assert_eq!(internal.id(), Some(&serde_json::json!(7))); + + // And an incoming server response converts to rmcp correctly + let incoming_response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(7), + result: serde_json::json!({"tools": []}), + }); + let rmcp_rx = internal_to_rmcp_client_rx(&incoming_response).unwrap(); + let v = serde_json::to_value(&rmcp_rx).unwrap(); + assert_eq!(v["id"], serde_json::json!(7)); + assert_eq!(v["result"]["tools"], serde_json::json!([])); + } + + // ── Layer 5: event_id-based request correlation (mirrors NostrServerWorker) ── + + #[test] + fn layer5_worker_uses_event_id_as_request_id() { + // Simulate the worker rewriting req.id to the Nostr event_id. + let event_id = "abc123def456"; + let mut req = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }; + + // Worker inbound path: rewrite id to event_id + req.id = serde_json::json!(event_id); + assert_eq!(req.id, serde_json::json!("abc123def456")); + + // Convert through rmcp bridge — ID must survive the roundtrip + let msg = JsonRpcMessage::Request(req); + let rmcp_rx = internal_to_rmcp_server_rx(&msg).unwrap(); + let v = serde_json::to_value(&rmcp_rx).unwrap(); + assert_eq!(v["id"], serde_json::json!("abc123def456")); + + // Simulate rmcp handler echoing the event_id back in the response + let rmcp_tx = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id)), + ); + let response = rmcp_server_tx_to_internal(rmcp_tx).unwrap(); + + // The response ID is the event_id — worker passes it directly to send_response + match response { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id.as_str(), Some(event_id)); + } + other => panic!("expected Response, got {other:?}"), + } + } + + #[test] + fn layer5_worker_two_clients_no_collision() { + // Two clients both send requests with id: 1. The worker rewrites each + // to its unique Nostr event_id, so no collision occurs. + let event_id_a = "aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111"; + let event_id_b = "bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222"; + + let mut req_a = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }; + let mut req_b = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }; + + // Worker rewrites both to their respective event IDs + req_a.id = serde_json::json!(event_id_a); + req_b.id = serde_json::json!(event_id_b); + + // After rewrite, the IDs are distinct even though both clients sent id: 1 + assert_ne!(req_a.id, req_b.id); + assert_eq!(req_a.id.as_str(), Some(event_id_a)); + assert_eq!(req_b.id.as_str(), Some(event_id_b)); + + // Responses echo back the event_id — each routes to the correct client + let rmcp_resp_a = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id_a)), + ); + let rmcp_resp_b = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id_b)), + ); + + let resp_a = rmcp_server_tx_to_internal(rmcp_resp_a).unwrap(); + let resp_b = rmcp_server_tx_to_internal(rmcp_resp_b).unwrap(); + + // Each response carries its own event_id — no cross-wiring + assert_eq!(resp_a.id().unwrap().as_str(), Some(event_id_a)); + assert_eq!(resp_b.id().unwrap().as_str(), Some(event_id_b)); + } + + #[test] + fn layer5_error_response_carries_event_id() { + // Error responses also carry the event_id for routing. + let event_id = "deadbeef"; + let mut req = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(5), + method: "tools/call".to_string(), + params: None, + }; + req.id = serde_json::json!(event_id); + + // rmcp handler returns an error with the rewritten event_id + let rmcp_err = ServerJsonRpcMessage::error( + rmcp::model::ErrorData { + code: rmcp::model::ErrorCode::METHOD_NOT_FOUND, + message: "Method not found".into(), + data: None, + }, + RequestId::String(std::sync::Arc::from(event_id)), + ); + let internal = rmcp_server_tx_to_internal(rmcp_err).unwrap(); + + match internal { + JsonRpcMessage::ErrorResponse(r) => { + assert_eq!(r.id.as_str(), Some(event_id)); + } + other => panic!("expected ErrorResponse, got {other:?}"), + } + } + + #[tokio::test] + async fn stateless_rmcp_roundtrip_over_mock_relay_preserves_correlation() { + let (server_pool, client_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool + .public_key() + .await + .expect("server mock relay pubkey") + .to_hex(); + + let server_transport = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_relay_urls(vec!["mock://relay".to_string()]) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::new(server_pool), + ) + .await + .expect("server transport"); + + let server_task = tokio::spawn(async move { + StatelessTestServer::new() + .serve(server_transport) + .await + .expect("server should start") + .waiting() + .await + .expect("server should keep running until aborted"); + }); + + let client_transport = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_relay_urls(vec!["mock://relay".to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_stateless(true), + Arc::new(client_pool), + ) + .await + .expect("client transport"); + + let client = StatelessTestClient + .serve(client_transport) + .await + .expect("stateless client should start"); + + let peer_info = client + .peer_info() + .expect("peer info from emulated initialize"); + assert_eq!(peer_info.server_info.name, "Emulated-Stateless-Server"); + + let tools = client + .list_all_tools() + .await + .expect("tools/list should succeed"); + assert!( + tools.iter().any(|tool| tool.name == "echo"), + "expected echo tool from server" + ); + + let result = client + .call_tool(CallToolRequestParams { + name: "echo".into(), + arguments: serde_json::from_value(serde_json::json!({ + "message": "hello from stateless test" + })) + .ok(), + meta: None, + task: None, + }) + .await + .expect("tools/call should succeed"); + + let echoed = result + .content + .iter() + .find_map(|content| match &content.raw { + rmcp::model::RawContent::Text(text) => Some(text.text.clone()), + _ => None, + }) + .expect("echo response text"); + assert_eq!(echoed, "Echo: hello from stateless test"); + + client.cancel().await.expect("client cancel"); + server_task.abort(); + } + + // ── Helper ────────────────────────────────────────────────────────────── + + fn make_request( + method: &str, + id: serde_json::Value, + params: Option, + ) -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params, + }) + } +} diff --git a/src/rmcp_transport/transport.rs b/src/rmcp_transport/transport.rs new file mode 100644 index 0000000..29e8fb4 --- /dev/null +++ b/src/rmcp_transport/transport.rs @@ -0,0 +1,31 @@ +//! rmcp transport integration for raw ContextVM Nostr transports. + +use crate::{ + core::error::Error, + rmcp_transport::worker::{NostrClientWorker, NostrServerWorker}, + transport::{client::NostrClientTransport, server::NostrServerTransport}, +}; + +impl rmcp::transport::IntoTransport + for NostrServerTransport +{ + /// Convert the raw server transport into rmcp's transport model via the + /// worker bridge. + fn into_transport( + self, + ) -> impl rmcp::transport::Transport + 'static { + NostrServerWorker::from_transport(self).into_transport() + } +} + +impl rmcp::transport::IntoTransport + for NostrClientTransport +{ + /// Convert the raw client transport into rmcp's transport model via the + /// worker bridge. + fn into_transport( + self, + ) -> impl rmcp::transport::Transport + 'static { + NostrClientWorker::from_transport(self).into_transport() + } +} diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs new file mode 100644 index 0000000..023354f --- /dev/null +++ b/src/rmcp_transport/worker.rs @@ -0,0 +1,520 @@ +//! rmcp worker adapters. +//! +//! This file defines wrapper types that bind existing ContextVM Nostr +//! transports to rmcp's worker abstraction. + +use crate::core::error::Result; +use crate::core::types::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest}; +use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; +use std::collections::HashSet; + +use super::convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, +}; + +const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::worker"; +const STATELESS_SYNTHETIC_EVENT_ID: &str = "contextvm-stateless-init"; + +fn synthetic_initialize_message() -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { + "name": "contextvm-stateless-client", + "version": "0.1.0" + } + })), + }) +} + +fn synthetic_initialized_notification() -> JsonRpcMessage { + JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }) +} + +fn should_inject_stateless_bootstrap( + initialized_clients: &HashSet, + client_pubkey: &str, + message: &JsonRpcMessage, +) -> bool { + if initialized_clients.contains(client_pubkey) { + return false; + } + + matches!(message, JsonRpcMessage::Request(req) if req.method != "initialize") +} + +fn is_synthetic_initialize_message(message: &JsonRpcMessage) -> bool { + matches!( + message, + JsonRpcMessage::Request(req) + if req.method == "initialize" + && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) + ) +} + +/// rmcp server worker wrapper for ContextVM Nostr server transport. +/// +/// Multiplexes all connected clients through a single rmcp service instance. +/// Inbound requests have their JSON-RPC `id` rewritten to the Nostr `event_id` +/// before being forwarded to the rmcp handler. Since event IDs are globally +/// unique (SHA-256 hashes), this eliminates collisions when different clients +/// use the same JSON-RPC request IDs. The transport's event-route store +/// handles response routing back to the originating client; server-initiated +/// notifications are broadcast to all initialized clients. +pub struct NostrServerWorker { + transport: NostrServerTransport, +} + +impl NostrServerWorker { + /// Create a new server worker from existing server transport config. + pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrServerTransport::new(signer, config).await?; + Ok(Self { transport }) + } + + /// Create a worker from an already-constructed raw transport. + pub fn from_transport(transport: NostrServerTransport) -> Self { + Self { transport } + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrServerTransport { + &self.transport + } +} + +impl Worker for NostrServerWorker { + type Error = crate::core::error::Error; + type Role = rmcp::RoleServer; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting server transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("server message receiver already taken".to_string()), + "taking server message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + let mut initialized_clients = HashSet::new(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + let crate::transport::server::IncomingRequest { + mut message, + event_id, + client_pubkey, + .. + } = incoming; + + let should_inject_bootstrap = should_inject_stateless_bootstrap( + &initialized_clients, + &client_pubkey, + &message, + ); + + if should_inject_bootstrap { + let synthetic_init = synthetic_initialize_message(); + let Some(rmcp_init) = internal_to_rmcp_server_rx(&synthetic_init) else { + break WorkerQuitReason::fatal( + Self::Error::Validation( + "failed converting synthetic initialize request to rmcp format".to_string(), + ), + "converting synthetic initialize request", + ); + }; + + if let Err(reason) = context.send_to_handler(rmcp_init).await { + break reason; + } + + let initialized = synthetic_initialized_notification(); + let Some(rmcp_initialized) = internal_to_rmcp_server_rx(&initialized) else { + break WorkerQuitReason::fatal( + Self::Error::Validation( + "failed converting synthetic initialized notification to rmcp format".to_string(), + ), + "converting synthetic initialized notification", + ); + }; + + if let Err(reason) = context.send_to_handler(rmcp_initialized).await { + break reason; + } + + initialized_clients.insert(client_pubkey.clone()); + } + + if matches!(&message, JsonRpcMessage::Request(req) if req.method == "initialize") + || matches!(&message, JsonRpcMessage::Notification(n) if n.method == "notifications/initialized") + { + initialized_clients.insert(client_pubkey.clone()); + } + + // Rewrite real wire requests to the Nostr event_id. + // Synthetic stateless bootstrap messages must retain their + // sentinel ID so their responses can be dropped before they + // ever touch transport correlation. + if !is_synthetic_initialize_message(&message) { + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!(event_id); + } + } + + if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!( + target: LOG_TARGET, + "Failed to convert incoming server-side message to rmcp format" + ); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_server_tx_to_internal(outbound.message) { + self.forward_server_internal(internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp server message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!( + target: LOG_TARGET, + error = %e, + "Failed to close server transport cleanly" + ); + } + + Err(quit_reason) + } +} + +/// rmcp client worker wrapper for ContextVM Nostr client transport. +pub struct NostrClientWorker { + transport: NostrClientTransport, +} + +impl NostrClientWorker { + /// Create a new client worker from existing client transport config. + pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrClientTransport::new(signer, config).await?; + Ok(Self { transport }) + } + + /// Create a worker from an already-constructed raw transport. + pub fn from_transport(transport: NostrClientTransport) -> Self { + Self { transport } + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrClientTransport { + &self.transport + } +} + +impl Worker for NostrClientWorker { + type Error = crate::core::error::Error; + type Role = rmcp::RoleClient; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting client transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("client message receiver already taken".to_string()), + "taking client message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + if let Some(rmcp_msg) = internal_to_rmcp_client_rx(&incoming) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!( + target: LOG_TARGET, + "Failed to convert incoming client-side message to rmcp format" + ); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_client_tx_to_internal(outbound.message) { + self.transport.send(&internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp client message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!( + target: LOG_TARGET, + error = %e, + "Failed to close client transport cleanly" + ); + } + + Err(quit_reason) + } +} + +impl NostrServerWorker { + /// Forward an outbound message from the rmcp handler to the Nostr transport. + /// + /// Response IDs carry the Nostr event_id set during ingest. The transport's + /// `send_response` uses this to look up the route (client_pubkey + + /// original_request_id) and deliver the response to the correct client. + /// Notifications and server-initiated requests are broadcast to all + /// initialized clients. + async fn forward_server_internal(&mut self, message: JsonRpcMessage) -> Result<()> { + match message { + JsonRpcMessage::Response(resp) => { + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server response id is not a string event_id".to_string(), + ) + })?; + + if event_id == STATELESS_SYNTHETIC_EVENT_ID { + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id, + "Dropping synthetic initialize response before wire transport" + ); + return Ok(()); + } + + self.transport + .send_response(&event_id, JsonRpcMessage::Response(resp)) + .await + } + JsonRpcMessage::ErrorResponse(resp) => { + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server error response id is not a string event_id".to_string(), + ) + })?; + + if event_id == STATELESS_SYNTHETIC_EVENT_ID { + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id, + "Dropping synthetic initialize error before wire transport" + ); + return Ok(()); + } + + self.transport + .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) + .await + } + JsonRpcMessage::Notification(notification) => { + let message = JsonRpcMessage::Notification(notification); + self.transport.broadcast_notification(&message).await + } + JsonRpcMessage::Request(request) => { + let message = JsonRpcMessage::Request(request); + self.transport.broadcast_notification(&message).await + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::types::JsonRpcResponse; + + #[test] + fn test_should_inject_stateless_bootstrap_for_first_non_initialize_request() { + let initialized_clients = HashSet::new(); + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(should_inject_stateless_bootstrap( + &initialized_clients, + "client-a", + &message, + )); + } + + #[test] + fn test_should_not_inject_stateless_bootstrap_for_real_initialize() { + let initialized_clients = HashSet::new(); + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(!should_inject_stateless_bootstrap( + &initialized_clients, + "client-a", + &message, + )); + } + + #[test] + fn test_synthetic_initialize_keeps_sentinel_id() { + let message = synthetic_initialize_message(); + + match message { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID)); + assert_eq!(req.method, "initialize"); + } + other => panic!("expected request, got {other:?}"), + } + } + + #[test] + fn test_real_request_is_rewritten_to_event_id() { + let mut message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!("real-event-id"); + } + + match message { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!("real-event-id")); + } + other => panic!("expected request, got {other:?}"), + } + } + + #[test] + fn test_synthetic_initialize_response_uses_sentinel_for_drop() { + let message = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + result: serde_json::json!({}), + }); + + match message { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id.as_str(), Some(STATELESS_SYNTHETIC_EVENT_ID)); + } + other => panic!("expected response, got {other:?}"), + } + } + + #[test] + fn test_synthetic_initialized_notification_shape() { + let message = synthetic_initialized_notification(); + match message { + JsonRpcMessage::Notification(notification) => { + assert_eq!(notification.method, "notifications/initialized"); + } + other => panic!("expected notification, got {other:?}"), + } + } + + #[test] + fn test_is_synthetic_initialize_message_detects_sentinel() { + assert!(is_synthetic_initialize_message( + &synthetic_initialize_message() + )); + } +} diff --git a/src/transport/base.rs b/src/transport/base.rs index 0c19a4a..bacb9c0 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -9,7 +9,9 @@ use crate::core::serializers; use crate::core::types::{EncryptionMode, JsonRpcMessage}; use crate::core::validation; use crate::encryption; -use crate::relay::RelayPool; +use crate::relay::RelayPoolTrait; + +const LOG_TARGET: &str = "contextvm_sdk::transport::base"; /// Shared transport logic for both client and server. /// @@ -18,7 +20,7 @@ use crate::relay::RelayPool; /// and [`NostrServerTransport`](super::server::NostrServerTransport). pub struct BaseTransport { /// The relay pool for publishing and subscribing to Nostr events. - pub relay_pool: Arc, + pub relay_pool: Arc, /// The encryption policy for outgoing messages. pub encryption_mode: EncryptionMode, /// Whether the transport is currently connected to relays. @@ -53,39 +55,39 @@ impl BaseTransport { /// Subscribe to events targeting a pubkey (both regular and encrypted). /// - /// Uses two filters: one for ephemeral ContextVM messages (kind 25910) - /// with `since: now()`, and one for NIP-59 gift wraps (kind 1059) without - /// a `since` constraint. Gift wraps use randomized timestamps per NIP-59, - /// so a `since: now()` filter would reject most incoming encrypted messages. + /// Uses three filters: one for ephemeral ContextVM messages (kind 25910) + /// and two for NIP-59 gift wraps (kinds 1059 and 21059). pub async fn subscribe_for_pubkey(&self, pubkey: &PublicKey) -> Result<()> { let p_tag = pubkey.to_hex(); + let now = Timestamp::now(); - // Ephemeral ContextVM messages — safe to use since:now() let ephemeral_filter = Filter::new() .kind(Kind::Custom(CTXVM_MESSAGES_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) - .since(Timestamp::now()); + .since(now); - // NIP-59 gift wraps — timestamps are randomized (up to ±48h or more), - // so we must NOT use since:now(). Limit to recent window instead. - let two_days_ago = Timestamp::from(Timestamp::now().as_u64().saturating_sub(2 * 24 * 3600)); let gift_wrap_filter = Filter::new() .kind(Kind::Custom(GIFT_WRAP_KIND)) - .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag) - .since(two_days_ago); + .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) + .since(now); - self.relay_pool.subscribe(vec![ephemeral_filter, gift_wrap_filter]).await + let ephemeral_gift_wrap_filter = Filter::new() + .kind(Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)) + .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag) + .since(now); + + self.relay_pool + .subscribe(vec![ + ephemeral_filter, + gift_wrap_filter, + ephemeral_gift_wrap_filter, + ]) + .await } /// Convert a Nostr event to an MCP message with validation. pub fn convert_event_to_mcp(&self, content: &str) -> Option { - if !validation::validate_message_size(content) { - tracing::warn!("Message size validation failed: {} bytes", content.len()); - return None; - } - - let value: serde_json::Value = serde_json::from_str(content).ok()?; - validation::validate_message(&value) + validation::validate_and_parse(content) } /// Create a signed Nostr event for an MCP message. @@ -99,9 +101,62 @@ impl BaseTransport { self.relay_pool.sign(builder).await } + /// Prepare an MCP message for publishing without actually publishing it. + /// + /// Signs (and optionally gift-wraps) the event, returning the inner signed + /// event ID together with the final event that should be published to relays. + pub async fn prepare_mcp_message( + &self, + message: &JsonRpcMessage, + recipient: &PublicKey, + kind: u16, + tags: Vec, + is_encrypted: Option, + gift_wrap_kind: Option, + ) -> Result<(EventId, Event)> { + let should_encrypt = self.should_encrypt(kind, is_encrypted); + + let event = self.create_signed_event(message, kind, tags).await?; + let signed_event_id = event.id; + + if should_encrypt { + let event_json = + serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?; + let signer = self + .relay_pool + .signer() + .await + .map_err(|e| Error::Encryption(e.to_string()))?; + let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND); + let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind( + &signer, + recipient, + &event_json, + selected_gift_wrap_kind, + ) + .await?; + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + envelope_id = %gift_wrap_event.id, + gift_wrap_kind = selected_gift_wrap_kind, + "Prepared encrypted MCP message" + ); + Ok((signed_event_id, gift_wrap_event)) + } else { + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + "Prepared unencrypted MCP message" + ); + Ok((signed_event_id, event)) + } + } + /// Send an MCP message to a recipient, optionally encrypting. /// - /// Returns the event ID of the published event. + /// Returns the signed MCP event ID. + /// When encrypted, this is the inner signed event ID. pub async fn send_mcp_message( &self, message: &JsonRpcMessage, @@ -109,29 +164,49 @@ impl BaseTransport { kind: u16, tags: Vec, is_encrypted: Option, + gift_wrap_kind: Option, ) -> Result { let should_encrypt = self.should_encrypt(kind, is_encrypted); let event = self.create_signed_event(message, kind, tags).await?; + let signed_event_id = event.id; if should_encrypt { // Single-layer gift wrap: JSON.stringify(signedEvent) → NIP-44 encrypt // This matches the JS/TS SDK's encryptMessage(JSON.stringify(event), recipient) - let event_json = serde_json::to_string(&event) + let event_json = + serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?; + let signer = self + .relay_pool + .signer() + .await .map_err(|e| Error::Encryption(e.to_string()))?; - let signer = self.relay_pool.client().signer().await - .map_err(|e| Error::Encryption(e.to_string()))?; - let gift_wrap_event = encryption::gift_wrap_single_layer( - &signer, recipient, &event_json, - ).await?; - let event_id = self.relay_pool.publish_event(&gift_wrap_event).await?; - tracing::debug!(event_id = %event_id, "Sent encrypted MCP message"); - Ok(event_id) + let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND); + let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind( + &signer, + recipient, + &event_json, + selected_gift_wrap_kind, + ) + .await?; + self.relay_pool.publish_event(&gift_wrap_event).await?; + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + envelope_id = %gift_wrap_event.id, + gift_wrap_kind = selected_gift_wrap_kind, + "Sent encrypted MCP message" + ); } else { - let event_id = self.relay_pool.publish_event(&event).await?; - tracing::debug!(event_id = %event_id, "Sent unencrypted MCP message"); - Ok(event_id) + self.relay_pool.publish_event(&event).await?; + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + "Sent unencrypted MCP message" + ); } + + Ok(signed_event_id) } /// Determine whether a message should be encrypted. @@ -157,13 +232,27 @@ impl BaseTransport { pub fn create_response_tags(pubkey: &PublicKey, event_id: &EventId) -> Vec { vec![Tag::public_key(*pubkey), Tag::event(*event_id)] } + + /// Compose outbound event tags in canonical order: + /// routing (p, e) -> discovery (one-shot caps) -> negotiation (pmi, persistent). + pub fn compose_outbound_tags( + base_tags: &[Tag], + discovery_tags: &[Tag], + negotiation_tags: &[Tag], + ) -> Vec { + let mut tags = + Vec::with_capacity(base_tags.len() + discovery_tags.len() + negotiation_tags.len()); + tags.extend_from_slice(base_tags); + tags.extend_from_slice(discovery_tags); + tags.extend_from_slice(negotiation_tags); + tags + } } #[cfg(test)] mod tests { use super::*; use crate::core::types::*; - use nostr_sdk::prelude::*; // Test should_encrypt logic without constructing full BaseTransport fn should_encrypt(mode: EncryptionMode, kind: u16, is_encrypted: Option) -> bool { @@ -179,24 +268,60 @@ mod tests { #[test] fn test_should_encrypt_disabled_mode() { - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, None)); - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, Some(true))); - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, Some(false))); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + None + )); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + Some(true) + )); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + Some(false) + )); } #[test] fn test_should_encrypt_required_mode() { - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, None)); - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, Some(false))); - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, Some(true))); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + None + )); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + Some(false) + )); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + Some(true) + )); } #[test] fn test_should_encrypt_optional_mode() { // Default (None) → true - assert!(should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, None)); - assert!(should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, Some(true))); - assert!(!should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, Some(false))); + assert!(should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + None + )); + assert!(should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + Some(true) + )); + assert!(!should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + Some(false) + )); } #[test] @@ -224,10 +349,9 @@ mod tests { let keys = Keys::generate(); let pubkey = keys.public_key(); // Create a dummy event ID - let event_id = EventId::from_hex( - "0000000000000000000000000000000000000000000000000000000000000001", - ) - .unwrap(); + let event_id = + EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(); let tags = BaseTransport::create_response_tags(&pubkey, &event_id); assert_eq!(tags.len(), 2); @@ -287,4 +411,57 @@ mod tests { let big = "x".repeat(MAX_MESSAGE_SIZE + 1); assert!(!crate::core::validation::validate_message_size(&big)); } + + // ── compose_outbound_tags ────────────────────────────────── + + fn make_custom_tag(name: &str) -> Tag { + Tag::custom(TagKind::Custom(name.into()), Vec::::new()) + } + + #[test] + fn compose_outbound_tags_ordering() { + let keys = Keys::generate(); + let base = vec![Tag::public_key(keys.public_key())]; + let discovery = vec![make_custom_tag("support_encryption")]; + let negotiation = vec![make_custom_tag("pmi")]; + + let result = BaseTransport::compose_outbound_tags(&base, &discovery, &negotiation); + assert_eq!(result.len(), 3); + assert_eq!(result[0].clone().to_vec()[0], "p"); + assert_eq!(result[1].clone().to_vec()[0], "support_encryption"); + assert_eq!(result[2].clone().to_vec()[0], "pmi"); + } + + #[test] + fn compose_outbound_tags_empty_discovery() { + let keys = Keys::generate(); + let base = vec![Tag::public_key(keys.public_key())]; + let negotiation = vec![make_custom_tag("pmi")]; + + let result = BaseTransport::compose_outbound_tags(&base, &[], &negotiation); + assert_eq!(result.len(), 2); + assert_eq!(result[0].clone().to_vec()[0], "p"); + assert_eq!(result[1].clone().to_vec()[0], "pmi"); + } + + #[test] + fn compose_outbound_tags_all_empty() { + let result = BaseTransport::compose_outbound_tags(&[], &[], &[]); + assert!(result.is_empty()); + } + + #[test] + fn compose_outbound_tags_preserves_all_elements() { + let discovery = vec![ + make_custom_tag("support_encryption"), + make_custom_tag("support_encryption_ephemeral"), + ]; + let result = BaseTransport::compose_outbound_tags(&[], &discovery, &[]); + assert_eq!(result.len(), 2); + assert_eq!(result[0].clone().to_vec()[0], "support_encryption"); + assert_eq!( + result[1].clone().to_vec()[0], + "support_encryption_ephemeral" + ); + } } diff --git a/src/transport/client.rs b/src/transport/client.rs deleted file mode 100644 index 7dffcf5..0000000 --- a/src/transport/client.rs +++ /dev/null @@ -1,318 +0,0 @@ -//! Client-side Nostr transport for ContextVM. -//! -//! Connects to a remote MCP server over Nostr. Sends JSON-RPC requests as -//! kind 25910 events, correlates responses via `e` tag. - -use std::collections::HashSet; -use std::sync::Arc; -use std::time::Duration; - -use nostr_sdk::prelude::*; -use tokio::sync::RwLock; - -use crate::core::constants::*; -use crate::core::error::{Error, Result}; -use crate::core::serializers; -use crate::core::types::*; -use crate::encryption; -use crate::relay::RelayPool; -use crate::transport::base::BaseTransport; - -/// Configuration for the client transport. -pub struct NostrClientTransportConfig { - /// Relay URLs to connect to. - pub relay_urls: Vec, - /// The server's public key (hex). - pub server_pubkey: String, - /// Encryption mode. - pub encryption_mode: EncryptionMode, - /// Stateless mode: emulate initialize response locally. - pub is_stateless: bool, - /// Response timeout (default: 30s). - pub timeout: Duration, -} - -impl Default for NostrClientTransportConfig { - fn default() -> Self { - Self { - relay_urls: vec!["wss://relay.damus.io".to_string()], - server_pubkey: String::new(), - encryption_mode: EncryptionMode::Optional, - is_stateless: false, - timeout: Duration::from_secs(30), - } - } -} - -/// Client-side Nostr transport for sending MCP requests and receiving responses. -pub struct NostrClientTransport { - base: BaseTransport, - config: NostrClientTransportConfig, - server_pubkey: PublicKey, - /// Pending request event IDs awaiting responses. - pending_requests: Arc>>, - /// Channel for receiving processed MCP messages from the event loop. - message_tx: tokio::sync::mpsc::UnboundedSender, - message_rx: Option>, -} - -impl NostrClientTransport { - /// Create a new client transport. - pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result - where - T: IntoNostrSigner, - { - let server_pubkey = PublicKey::from_hex(&config.server_pubkey) - .map_err(|e| Error::Other(format!("Invalid server pubkey: {e}")))?; - - let relay_pool = Arc::new(RelayPool::new(signer).await?); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - Ok(Self { - base: BaseTransport { - relay_pool, - encryption_mode: config.encryption_mode, - is_connected: false, - }, - config, - server_pubkey, - pending_requests: Arc::new(RwLock::new(HashSet::new())), - message_tx: tx, - message_rx: Some(rx), - }) - } - - /// Connect and start listening for responses. - pub async fn start(&mut self) -> Result<()> { - self.base.connect(&self.config.relay_urls).await?; - - let pubkey = self.base.get_public_key().await?; - tracing::info!(pubkey = %pubkey.to_hex(), "Client transport started"); - - self.base.subscribe_for_pubkey(&pubkey).await?; - - // Spawn event loop - let client = self.base.relay_pool.client().clone(); - let pending = self.pending_requests.clone(); - let server_pubkey = self.server_pubkey; - let tx = self.message_tx.clone(); - let encryption_mode = self.config.encryption_mode; - - tokio::spawn(async move { - Self::event_loop(client, pending, server_pubkey, tx, encryption_mode).await; - }); - - Ok(()) - } - - /// Close the transport. - pub async fn close(&mut self) -> Result<()> { - self.base.disconnect().await - } - - /// Send a JSON-RPC message to the server. - pub async fn send(&self, message: &JsonRpcMessage) -> Result<()> { - // Stateless mode: emulate initialize response - if self.config.is_stateless { - if let JsonRpcMessage::Request(ref req) = message { - if req.method == "initialize" { - self.emulate_initialize_response(&req.id); - return Ok(()); - } - } - if let JsonRpcMessage::Notification(ref n) = message { - if n.method == "notifications/initialized" { - return Ok(()); - } - } - } - - let tags = BaseTransport::create_recipient_tags(&self.server_pubkey); - let event_id = self - .base - .send_mcp_message(message, &self.server_pubkey, CTXVM_MESSAGES_KIND, tags, None) - .await?; - - self.pending_requests - .write() - .await - .insert(event_id.to_hex()); - - Ok(()) - } - - /// Take the message receiver for consuming incoming messages. - pub fn take_message_receiver( - &mut self, - ) -> Option> { - self.message_rx.take() - } - - fn emulate_initialize_response(&self, request_id: &serde_json::Value) { - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: request_id.clone(), - result: serde_json::json!({ - "protocolVersion": "2025-03-26", - "serverInfo": { - "name": "Emulated-Stateless-Server", - "version": "1.0.0" - }, - "capabilities": { - "tools": { "listChanged": true }, - "prompts": { "listChanged": true }, - "resources": { "subscribe": true, "listChanged": true } - } - }), - }); - let _ = self.message_tx.send(response); - } - - async fn event_loop( - client: Arc, - pending: Arc>>, - server_pubkey: PublicKey, - tx: tokio::sync::mpsc::UnboundedSender, - _encryption_mode: EncryptionMode, - ) { - let mut notifications = client.notifications(); - - while let Ok(notification) = notifications.recv().await { - if let RelayPoolNotification::Event { event, .. } = notification { - // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) { - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { - Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); - continue; - } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => { - let e_tag = serializers::get_tag_value(&inner.tags, "e"); - (inner.content, inner.pubkey, e_tag) - } - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); - continue; - } - } - } - Err(e) => { - tracing::error!("Failed to decrypt gift wrap: {e}"); - continue; - } - } - } else { - let e_tag = serializers::get_tag_value(&event.tags, "e"); - (event.content.clone(), event.pubkey, e_tag) - }; - - // Verify it's from our server - if actual_pubkey != server_pubkey { - tracing::debug!("Skipping event from unexpected pubkey"); - continue; - } - - // Correlate response - if let Some(ref correlated_id) = e_tag { - let is_pending = pending.read().await.contains(correlated_id.as_str()); - if !is_pending { - tracing::warn!(e_tag = %correlated_id, "Response for unknown request"); - continue; - } - } - - // Parse MCP message - if let Some(mcp_msg) = - serializers::nostr_event_to_mcp_message(&actual_event_content) - { - // Clean up pending request - if let Some(ref correlated_id) = e_tag { - pending.write().await.remove(correlated_id.as_str()); - } - let _ = tx.send(mcp_msg); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::types::*; - - #[test] - fn test_config_defaults() { - let config = NostrClientTransportConfig::default(); - assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); - assert!(config.server_pubkey.is_empty()); - assert_eq!(config.encryption_mode, EncryptionMode::Optional); - assert!(!config.is_stateless); - assert_eq!(config.timeout, Duration::from_secs(30)); - } - - #[test] - fn test_stateless_config() { - let config = NostrClientTransportConfig { - is_stateless: true, - ..Default::default() - }; - assert!(config.is_stateless); - } - - #[test] - fn test_stateless_emulated_initialize_response_shape() { - // Verify the emulated response has the expected structure - let request_id = serde_json::json!(1); - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: request_id.clone(), - result: serde_json::json!({ - "protocolVersion": "2025-03-26", - "serverInfo": { - "name": "Emulated-Stateless-Server", - "version": "1.0.0" - }, - "capabilities": { - "tools": { "listChanged": true }, - "prompts": { "listChanged": true }, - "resources": { "subscribe": true, "listChanged": true } - } - }), - }); - assert!(response.is_response()); - assert_eq!(response.id(), Some(&serde_json::json!(1))); - - if let JsonRpcMessage::Response(r) = &response { - assert!(r.result.get("capabilities").is_some()); - assert!(r.result.get("serverInfo").is_some()); - let server_info = r.result.get("serverInfo").unwrap(); - assert_eq!(server_info.get("name").unwrap().as_str().unwrap(), "Emulated-Stateless-Server"); - } - } - - #[test] - fn test_stateless_mode_initialize_request_detection() { - let init_req = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: "initialize".to_string(), - params: None, - }); - assert_eq!(init_req.method(), Some("initialize")); - - let init_notif = JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/initialized".to_string(), - params: None, - }); - assert_eq!(init_notif.method(), Some("notifications/initialized")); - } -} diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs new file mode 100644 index 0000000..0fbcd9f --- /dev/null +++ b/src/transport/client/correlation_store.rs @@ -0,0 +1,212 @@ +//! Client-side correlation store for tracking pending request event IDs. + +use std::num::NonZeroUsize; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use lru::LruCache; +use tokio::sync::RwLock; + +use crate::core::constants::DEFAULT_LRU_SIZE; + +/// A pending request tracked by the correlation store. +#[derive(Debug, Clone)] +pub struct PendingRequest { + /// The original JSON-RPC request ID before event-ID replacement. + pub original_id: serde_json::Value, + /// Whether this request is an `initialize` handshake. + pub is_initialize: bool, + /// When the request was registered. + pub registered_at: Instant, +} + +/// Tracks pending request event IDs and their original request IDs on the client side. +/// +/// An optional capacity limit enables LRU eviction of the oldest entry when the +/// store is full. +#[derive(Clone)] +pub struct ClientCorrelationStore { + pending_requests: Arc>>, +} + +impl Default for ClientCorrelationStore { + fn default() -> Self { + Self::new() + } +} + +impl ClientCorrelationStore { + pub fn new() -> Self { + Self::with_max_pending(DEFAULT_LRU_SIZE) + } + + /// Create a store with an upper bound on pending requests. + /// When the limit is reached the oldest entry is evicted. + pub fn with_max_pending(max_pending: usize) -> Self { + Self { + pending_requests: Arc::new(RwLock::new(LruCache::new( + NonZeroUsize::new(max_pending).unwrap_or(NonZeroUsize::new(1).unwrap()), + ))), + } + } + + /// Register a pending request with its original JSON-RPC request ID. + pub async fn register( + &self, + event_id: String, + original_id: serde_json::Value, + is_initialize: bool, + ) { + self.pending_requests.write().await.push( + event_id, + PendingRequest { + original_id, + is_initialize, + registered_at: Instant::now(), + }, + ); + } + + /// Check whether a given event ID corresponds to an `initialize` request. + pub async fn is_initialize_request(&self, event_id: &str) -> bool { + self.pending_requests + .read() + .await + .peek(event_id) + .is_some_and(|r| r.is_initialize) + } + + pub async fn contains(&self, event_id: &str) -> bool { + self.pending_requests.read().await.contains(event_id) + } + + /// Remove a pending request. Returns `true` if the key existed. + pub async fn remove(&self, event_id: &str) -> bool { + self.pending_requests.write().await.pop(event_id).is_some() + } + + /// Retrieve the original request ID for a given event ID without removing it. + pub async fn get_original_id(&self, event_id: &str) -> Option { + self.pending_requests + .read() + .await + .peek(event_id) + .map(|r| r.original_id.clone()) + } + + /// Number of pending requests currently tracked. + pub async fn count(&self) -> usize { + self.pending_requests.read().await.len() + } + + /// Remove all entries older than `timeout`. Returns the number of entries removed. + pub async fn sweep_expired(&self, timeout: Duration) -> usize { + let now = Instant::now(); + let mut cache = self.pending_requests.write().await; + let mut expired_keys = Vec::new(); + + for (key, entry) in cache.iter() { + if now.duration_since(entry.registered_at) >= timeout { + expired_keys.push(key.clone()); + } + } + + let count = expired_keys.len(); + for key in expired_keys { + cache.pop(&key); + } + count + } + + pub async fn clear(&self) { + self.pending_requests.write().await.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn remove_nonexistent_is_noop() { + let store = ClientCorrelationStore::new(); + assert!(!store.remove("nonexistent").await); + assert!(!store.contains("nonexistent").await); + } + + #[tokio::test] + async fn contains_after_clear() { + let store = ClientCorrelationStore::new(); + store + .register("e1".into(), serde_json::Value::Null, false) + .await; + store + .register("e2".into(), serde_json::Value::Null, false) + .await; + assert!(store.contains("e1").await); + store.clear().await; + assert!(!store.contains("e1").await); + assert!(!store.contains("e2").await); + } + + #[tokio::test] + async fn register_and_remove_roundtrip() { + let store = ClientCorrelationStore::new(); + store + .register("e1".into(), serde_json::Value::Null, false) + .await; + assert!(store.contains("e1").await); + assert!(store.remove("e1").await); + assert!(!store.contains("e1").await); + } + + #[tokio::test] + async fn default_store_is_bounded() { + let store = ClientCorrelationStore::new(); + for i in 0..=DEFAULT_LRU_SIZE { + store + .register(format!("e{i}"), serde_json::Value::Null, false) + .await; + } + + assert_eq!(store.count().await, DEFAULT_LRU_SIZE); + assert!(!store.contains("e0").await); + assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await); + } + + #[tokio::test] + async fn sweep_expired_removes_only_stale_entries() { + let store = ClientCorrelationStore::new(); + + // Insert an entry that will be "old" by the time we sweep. + store + .register("old".into(), serde_json::json!(1), false) + .await; + + // Sleep so "old" entry ages past the threshold. + tokio::time::sleep(Duration::from_millis(20)).await; + + // Insert a fresh entry. + store + .register("fresh".into(), serde_json::json!(2), false) + .await; + + // Sweep with a 10ms timeout — "old" should be removed, "fresh" should remain. + let swept = store.sweep_expired(Duration::from_millis(10)).await; + assert_eq!(swept, 1); + assert!(!store.contains("old").await); + assert!(store.contains("fresh").await); + } + + #[tokio::test] + async fn sweep_expired_returns_zero_when_nothing_expired() { + let store = ClientCorrelationStore::new(); + store + .register("e1".into(), serde_json::Value::Null, false) + .await; + + let swept = store.sweep_expired(Duration::from_secs(60)).await; + assert_eq!(swept, 0); + assert!(store.contains("e1").await); + } +} diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs new file mode 100644 index 0000000..31f80c7 --- /dev/null +++ b/src/transport/client/mod.rs @@ -0,0 +1,1218 @@ +//! Client-side Nostr transport for ContextVM. +//! +//! Connects to a remote MCP server over Nostr. Sends JSON-RPC requests as +//! kind 25910 events, correlates responses via `e` tag. + +pub mod correlation_store; + +pub use correlation_store::ClientCorrelationStore; + +use std::num::NonZeroUsize; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use lru::LruCache; +use nostr_sdk::prelude::*; +use tokio_util::sync::CancellationToken; + +use crate::core::constants::*; +use crate::core::error::{Error, Result}; +use crate::core::serializers; +use crate::core::types::*; +use crate::core::validation; +use crate::encryption; +use crate::relay::{RelayPool, RelayPoolTrait}; +use crate::transport::base::BaseTransport; +use crate::transport::discovery_tags::{parse_discovered_peer_capabilities, PeerCapabilities}; + +const LOG_TARGET: &str = "contextvm_sdk::transport::client"; + +/// Configuration for the client transport. +#[non_exhaustive] +pub struct NostrClientTransportConfig { + /// Relay URLs to connect to. + pub relay_urls: Vec, + /// The server's public key (hex). + pub server_pubkey: String, + /// Encryption mode. + pub encryption_mode: EncryptionMode, + /// Gift-wrap policy for encrypted messages. + pub gift_wrap_mode: GiftWrapMode, + /// Stateless mode: emulate initialize response locally. + pub is_stateless: bool, + /// Correlation-retention TTL for pending client requests (default: 30s). + /// + /// Stale pending entries older than this are swept from the correlation store. + /// This prevents leaks -- rmcp owns actual request timeout and cancellation. + /// Keep this value above your rmcp request timeout to avoid premature cleanup. + pub timeout: Duration, +} + +impl Default for NostrClientTransportConfig { + fn default() -> Self { + Self { + relay_urls: vec!["wss://relay.damus.io".to_string()], + server_pubkey: String::new(), + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + is_stateless: false, + timeout: Duration::from_secs(30), + } + } +} + +impl NostrClientTransportConfig { + /// Set the server's public key (hex). + pub fn with_server_pubkey(mut self, pubkey: impl Into) -> Self { + self.server_pubkey = pubkey.into(); + self + } + /// Set the encryption mode. + pub fn with_encryption_mode(mut self, mode: EncryptionMode) -> Self { + self.encryption_mode = mode; + self + } + /// Set the gift-wrap mode (CEP-19). + pub fn with_gift_wrap_mode(mut self, mode: GiftWrapMode) -> Self { + self.gift_wrap_mode = mode; + self + } + /// Enable or disable stateless mode. + pub fn with_stateless(mut self, stateless: bool) -> Self { + self.is_stateless = stateless; + self + } + /// Set the relay URLs to connect to. + pub fn with_relay_urls(mut self, urls: Vec) -> Self { + self.relay_urls = urls; + self + } + /// Set the correlation-retention TTL. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } +} + +/// Client-side Nostr transport for sending MCP requests and receiving responses. +pub struct NostrClientTransport { + base: BaseTransport, + config: NostrClientTransportConfig, + server_pubkey: PublicKey, + /// Pending request event IDs awaiting responses. + pending_requests: ClientCorrelationStore, + /// CEP-35: one-shot flag for client discovery tag emission. + has_sent_discovery_tags: AtomicBool, + /// CEP-35: learned server capabilities from inbound discovery tags. + discovered_server_capabilities: Arc>, + /// CEP-35: first inbound event carrying discovery tags (session baseline). + server_initialize_event: Arc>>, + /// Learned support for server-side ephemeral gift wraps. + server_supports_ephemeral: Arc, + /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). + /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success + /// so failed decrypt/verify can be retried on redelivery. + seen_gift_wrap_ids: Arc>>, + /// Channel for receiving processed MCP messages from the event loop. + message_tx: Option>, + message_rx: Option>, + /// Token used to cancel the spawned event loop on close(). + cancellation_token: CancellationToken, + /// Handle for the spawned event loop task. + event_loop_handle: Option>, +} + +impl NostrClientTransport { + /// Create a new client transport. + pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result + where + T: IntoNostrSigner, + { + let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %config.server_pubkey, + "Invalid server pubkey" + ); + Error::Other(format!("Invalid server pubkey: {error}")) + })?; + + let relay_pool: Arc = + Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for client transport" + ); + error + })?); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + stateless = config.is_stateless, + encryption_mode = ?config.encryption_mode, + "Created client transport" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + config, + server_pubkey, + pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), + seen_gift_wrap_ids, + message_tx: Some(tx), + message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + event_loop_handle: None, + }) + } + + /// Like [`new`](Self::new) but accepts an existing relay pool. + pub async fn with_relay_pool( + config: NostrClientTransportConfig, + relay_pool: Arc, + ) -> Result { + let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %config.server_pubkey, + "Invalid server pubkey" + ); + Error::Other(format!("Invalid server pubkey: {error}")) + })?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + stateless = config.is_stateless, + encryption_mode = ?config.encryption_mode, + "Created client transport (with_relay_pool)" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + config, + server_pubkey, + pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), + seen_gift_wrap_ids, + message_tx: Some(tx), + message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + event_loop_handle: None, + }) + } + + /// Connect and start listening for responses. + pub async fn start(&mut self) -> Result<()> { + self.base + .connect(&self.config.relay_urls) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to connect client transport to relays" + ); + error + })?; + + let pubkey = self.base.get_public_key().await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to fetch client transport public key" + ); + error + })?; + tracing::info!( + target: LOG_TARGET, + pubkey = %pubkey.to_hex(), + "Client transport started" + ); + + self.base + .subscribe_for_pubkey(&pubkey) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + pubkey = %pubkey.to_hex(), + "Failed to subscribe client transport for pubkey" + ); + error + })?; + + // Spawn event loop with cancellation support + let relay_pool = Arc::clone(&self.base.relay_pool); + let pending = self.pending_requests.clone(); + let server_pubkey = self.server_pubkey; + let tx = self + .message_tx + .as_ref() + .expect("message_tx must exist before start()") + .clone(); + let encryption_mode = self.config.encryption_mode; + let gift_wrap_mode = self.config.gift_wrap_mode; + let discovered_caps = self.discovered_server_capabilities.clone(); + let init_event = self.server_initialize_event.clone(); + let server_supports_ephemeral = self.server_supports_ephemeral.clone(); + let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); + let timeout = self.config.timeout; + let token = self.cancellation_token.child_token(); + + self.event_loop_handle = Some(tokio::spawn(async move { + Self::event_loop( + relay_pool, + pending, + server_pubkey, + tx, + encryption_mode, + gift_wrap_mode, + discovered_caps, + init_event, + server_supports_ephemeral, + seen_gift_wrap_ids, + timeout, + token, + ) + .await; + })); + + tracing::info!( + target: LOG_TARGET, + relay_count = self.config.relay_urls.len(), + "Client transport event loop spawned" + ); + Ok(()) + } + + /// Close the transport — cancels the event loop and disconnects from relays. + pub async fn close(&mut self) -> Result<()> { + self.cancellation_token.cancel(); + if let Some(handle) = self.event_loop_handle.take() { + let _ = handle.await; + } + self.message_tx.take(); + self.base.disconnect().await + } + + /// Send a JSON-RPC message to the server. + pub async fn send(&self, message: &JsonRpcMessage) -> Result<()> { + // Stateless mode: emulate initialize response + if self.config.is_stateless { + if let JsonRpcMessage::Request(ref req) = message { + if req.method == "initialize" { + self.emulate_initialize_response(&req.id); + return Ok(()); + } + } + if let JsonRpcMessage::Notification(ref n) = message { + if n.method == "notifications/initialized" { + return Ok(()); + } + } + } + + let is_request = message.is_request(); + let base_tags = BaseTransport::create_recipient_tags(&self.server_pubkey); + let discovery_tags = if is_request { + self.get_pending_client_discovery_tags() + } else { + vec![] + }; + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); + + let (event_id, publishable_event) = self + .base + .prepare_mcp_message( + message, + &self.server_pubkey, + CTXVM_MESSAGES_KIND, + tags, + None, + Some(self.choose_outbound_gift_wrap_kind()), + ) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %self.server_pubkey.to_hex(), + method = ?message.method(), + "Failed to prepare client message" + ); + error + })?; + + if let JsonRpcMessage::Request(ref req) = message { + let is_initialize = req.method == INITIALIZE_METHOD; + self.pending_requests + .register(event_id.to_hex(), req.id.clone(), is_initialize) + .await; + } + + if let Err(error) = self.base.relay_pool.publish_event(&publishable_event).await { + self.pending_requests.remove(&event_id.to_hex()).await; + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %self.server_pubkey.to_hex(), + method = ?message.method(), + "Failed to publish client message" + ); + return Err(error); + } + + // Flip one-shot flag only after successful publish + if is_request && !discovery_tags.is_empty() { + self.has_sent_discovery_tags.store(true, Ordering::Relaxed); + } + + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id.to_hex(), + method = ?message.method(), + "Sent client message" + ); + Ok(()) + } + + /// Take the message receiver for consuming incoming messages. + pub fn take_message_receiver( + &mut self, + ) -> Option> { + self.message_rx.take() + } + + fn emulate_initialize_response(&self, request_id: &serde_json::Value) { + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id.clone(), + result: serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "serverInfo": { + "name": "Emulated-Stateless-Server", + "version": "1.0.0" + }, + "capabilities": { + "tools": { "listChanged": true }, + "prompts": { "listChanged": true }, + "resources": { "subscribe": true, "listChanged": true } + } + }), + }); + if let Some(ref tx) = self.message_tx { + let _ = tx.send(response); + } + } + + #[allow(clippy::too_many_arguments)] + async fn event_loop( + relay_pool: Arc, + pending: ClientCorrelationStore, + server_pubkey: PublicKey, + tx: tokio::sync::mpsc::UnboundedSender, + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + discovered_caps: Arc>, + init_event: Arc>>, + server_supports_ephemeral: Arc, + seen_gift_wrap_ids: Arc>>, + timeout: Duration, + cancel: CancellationToken, + ) { + let mut notifications = relay_pool.notifications(); + // Sweep interval: half the timeout, clamped to [1s, 30s]. + let sweep_interval = (timeout / 2).clamp(Duration::from_secs(1), Duration::from_secs(30)); + let mut sweep_timer = + tokio::time::interval_at(tokio::time::Instant::now() + sweep_interval, sweep_interval); + + loop { + tokio::select! { + _ = cancel.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Client event loop cancelled" + ); + break; + } + result = notifications.recv() => { + let notification = match result { + Ok(n) => n, + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + target: LOG_TARGET, + skipped = n, + "Relay broadcast lagged, skipping missed events" + ); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + }; + Self::handle_notification( + ¬ification, + &pending, + server_pubkey, + &tx, + encryption_mode, + gift_wrap_mode, + &discovered_caps, + &init_event, + &server_supports_ephemeral, + &seen_gift_wrap_ids, + &relay_pool, + ) + .await; + } + _ = sweep_timer.tick() => { + let swept = pending.sweep_expired(timeout).await; + if swept > 0 { + tracing::warn!( + target: LOG_TARGET, + swept, + timeout_ms = timeout.as_millis() as u64, + "Swept stale pending requests (rmcp handles timeout errors)" + ); + } + } + } + } + } + + // ── CEP-35 discovery tag helpers ────────────────────────────── + + /// Constructs client capability tags based on config. + fn get_client_capability_tags(&self) -> Vec { + let mut tags = Vec::new(); + if self.config.encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if self.config.gift_wrap_mode != GiftWrapMode::Persistent { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + tags + } + + /// One-shot: returns capability tags if not yet sent, empty otherwise. + fn get_pending_client_discovery_tags(&self) -> Vec { + if self.has_sent_discovery_tags.load(Ordering::Relaxed) { + vec![] + } else { + self.get_client_capability_tags() + } + } + + /// Parses inbound event tags and updates learned server capabilities. + fn learn_server_discovery( + discovered_caps: &Mutex, + init_event: &Mutex>, + event: &Event, + ) { + let tag_vec: Vec = event.tags.clone().to_vec(); + let discovered = parse_discovered_peer_capabilities(&tag_vec); + if discovered.discovery_tags.is_empty() { + return; + } + + { + let mut caps = match discovered_caps.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + caps.supports_encryption |= discovered.capabilities.supports_encryption; + caps.supports_ephemeral_encryption |= + discovered.capabilities.supports_ephemeral_encryption; + caps.supports_oversized_transfer |= discovered.capabilities.supports_oversized_transfer; + } + + let mut stored = match init_event.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + if stored.is_none() { + *stored = Some(event.clone()); + } + // Note: TS SDK has an upgrade path where a later event with an InitializeResult + // replaces a non-initialize baseline. Not implemented here -- edge case only + // relevant if the first server message with discovery tags is a notification. + } + + /// Returns a clone of the first inbound event that carried server discovery tags. + pub fn get_server_initialize_event(&self) -> Option { + let guard = match self.server_initialize_event.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + guard.clone() + } + + /// Returns a snapshot of the learned server capabilities from discovery tags. + pub fn discovered_server_capabilities(&self) -> PeerCapabilities { + let guard = match self.discovered_server_capabilities.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + *guard + } + + #[allow(clippy::too_many_arguments)] + async fn handle_notification( + notification: &RelayPoolNotification, + pending: &ClientCorrelationStore, + server_pubkey: PublicKey, + tx: &tokio::sync::mpsc::UnboundedSender, + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + discovered_caps: &Arc>, + init_event: &Arc>>, + server_supports_ephemeral: &Arc, + seen_gift_wrap_ids: &Arc>>, + relay_pool: &Arc, + ) { + let event = match notification { + RelayPoolNotification::Event { event, .. } => event, + _ => return, + }; + + let is_gift_wrap = is_gift_wrap_kind(&event.kind); + let outer_kind = event.kind.as_u16(); + + // Enforce encryption mode before decrypt/parse. + if violates_encryption_policy(&event.kind, &encryption_mode) { + if is_gift_wrap { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping encrypted response because client encryption is disabled" + ); + } else { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping plaintext response because client encryption is required" + ); + } + return; + } + + // Enforce CEP-19 gift-wrap-mode policy. + if is_gift_wrap && !gift_wrap_mode.allows_kind(outer_kind) { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping gift wrap due to CEP-19 policy" + ); + return; + } + + // Handle gift-wrapped events + let (actual_event_content, actual_pubkey, e_tag, verified_tags, source_event) = + if is_gift_wrap { + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + return; + } + } + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match relay_pool.signer().await { + Ok(s) => s, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); + return; + } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, event).await { + Ok(decrypted_json) => match serde_json::from_str::(&decrypted_json) { + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!("Inner event signature verification failed: {e}"); + return; + } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } + let e_tag = serializers::get_tag_value(&inner.tags, "e"); + let inner_clone = inner.clone(); + (inner.content, inner.pubkey, e_tag, inner.tags, inner_clone) + } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" + ); + return; + } + }, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt gift wrap" + ); + return; + } + } + } else { + let e_tag = serializers::get_tag_value(&event.tags, "e"); + let event_clone: Event = (**event).clone(); + ( + event.content.clone(), + event.pubkey, + e_tag, + event.tags.clone(), + event_clone, + ) + }; + + // Verify it's from our server + if actual_pubkey != server_pubkey { + tracing::debug!( + target: LOG_TARGET, + event_pubkey = %actual_pubkey.to_hex(), + expected_pubkey = %server_pubkey.to_hex(), + "Skipping event from unexpected pubkey" + ); + return; + } + + // CEP-35: learn server capabilities from discovery tags + Self::learn_server_discovery(discovered_caps, init_event, &source_event); + + // CEP-19: learn ephemeral support from server + if Self::should_learn_ephemeral_support( + actual_pubkey, + server_pubkey, + if is_gift_wrap { Some(outer_kind) } else { None }, + &verified_tags, + ) { + server_supports_ephemeral.store(true, Ordering::Relaxed); + } + + // Correlate response + if let Some(ref correlated_id) = e_tag { + let is_pending = pending.contains(correlated_id.as_str()).await; + if !is_pending { + tracing::warn!( + target: LOG_TARGET, + correlated_event_id = %correlated_id, + "Response for unknown request" + ); + return; + } + } + + // Parse MCP message + if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { + // Drop uncorrelated responses and server-to-client requests (matches TS SDK). + match &mcp_msg { + JsonRpcMessage::Response(_) | JsonRpcMessage::ErrorResponse(_) + if e_tag.is_none() => + { + tracing::warn!( + target: LOG_TARGET, + "Dropping response/error without correlation `e` tag" + ); + return; + } + JsonRpcMessage::Request(_) => { + tracing::warn!( + target: LOG_TARGET, + method = ?mcp_msg.method(), + "Dropping server-to-client request (invalid in MCP)" + ); + return; + } + _ => {} + } + + // Clean up pending request + if let Some(ref correlated_id) = e_tag { + pending.remove(correlated_id.as_str()).await; + } + let _ = tx.send(mcp_msg); + } + } + + fn choose_outbound_gift_wrap_kind(&self) -> u16 { + match self.config.gift_wrap_mode { + GiftWrapMode::Persistent => GIFT_WRAP_KIND, + GiftWrapMode::Ephemeral => EPHEMERAL_GIFT_WRAP_KIND, + GiftWrapMode::Optional => { + if self.server_supports_ephemeral.load(Ordering::Relaxed) { + EPHEMERAL_GIFT_WRAP_KIND + } else { + GIFT_WRAP_KIND + } + } + } + } + + fn has_support_ephemeral_tag(tags: &Tags) -> bool { + tags.iter().any(|tag| { + tag.kind() + == TagKind::Custom( + crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into(), + ) + }) + } + + fn should_learn_ephemeral_support( + actual_pubkey: PublicKey, + server_pubkey: PublicKey, + event_kind: Option, + tags: &Tags, + ) -> bool { + actual_pubkey == server_pubkey + && (event_kind == Some(EPHEMERAL_GIFT_WRAP_KIND) + || Self::has_support_ephemeral_tag(tags)) + } + + /// Returns whether the client has learned ephemeral gift-wrap support from the server. + pub fn server_supports_ephemeral_encryption(&self) -> bool { + self.server_supports_ephemeral.load(Ordering::Relaxed) + } +} + +#[inline] +fn is_gift_wrap_kind(kind: &Kind) -> bool { + *kind == Kind::Custom(GIFT_WRAP_KIND) || *kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) +} + +/// Returns `true` when the inbound event kind violates the configured encryption +/// policy and must be dropped before any further processing. +#[inline] +fn violates_encryption_policy(kind: &Kind, mode: &EncryptionMode) -> bool { + let is_gift_wrap = is_gift_wrap_kind(kind); + (is_gift_wrap && *mode == EncryptionMode::Disabled) + || (!is_gift_wrap && *mode == EncryptionMode::Required) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_defaults() { + let config = NostrClientTransportConfig::default(); + assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); + assert!(config.server_pubkey.is_empty()); + assert_eq!(config.encryption_mode, EncryptionMode::Optional); + assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); + assert!(!config.is_stateless); + assert_eq!(config.timeout, Duration::from_secs(30)); + } + + #[test] + fn test_stateless_config() { + let config = NostrClientTransportConfig { + is_stateless: true, + ..Default::default() + }; + assert!(config.is_stateless); + } + + #[test] + fn test_custom_timeout_config() { + let config = NostrClientTransportConfig { + timeout: Duration::from_secs(60), + ..Default::default() + }; + assert_eq!(config.timeout, Duration::from_secs(60)); + } + + #[test] + fn test_has_support_ephemeral_tag_detects_capability() { + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + assert!(NostrClientTransport::has_support_ephemeral_tag(&tags)); + } + + #[test] + fn test_has_support_ephemeral_tag_absent() { + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )]); + assert!(!NostrClientTransport::has_support_ephemeral_tag(&tags)); + } + + #[test] + fn test_should_learn_ephemeral_support_requires_matching_server_pubkey() { + let server_keys = Keys::generate(); + let other_keys = Keys::generate(); + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + + assert!(!NostrClientTransport::should_learn_ephemeral_support( + other_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &tags, + )); + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &tags, + )); + } + + #[test] + fn test_should_learn_from_ephemeral_kind_even_without_tag() { + let server_keys = Keys::generate(); + let empty_tags = Tags::from_list(vec![]); + + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &empty_tags, + )); + } + + #[test] + fn test_should_learn_from_tag_without_ephemeral_kind() { + let server_keys = Keys::generate(); + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(GIFT_WRAP_KIND), // persistent kind, but tag present + &tags, + )); + } + + #[test] + fn test_stateless_emulated_initialize_response_shape() { + let request_id = serde_json::json!(1); + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request_id.clone(), + result: serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "serverInfo": { + "name": "Emulated-Stateless-Server", + "version": "1.0.0" + }, + "capabilities": { + "tools": { "listChanged": true }, + "prompts": { "listChanged": true }, + "resources": { "subscribe": true, "listChanged": true } + } + }), + }); + assert!(response.is_response()); + assert_eq!(response.id(), Some(&serde_json::json!(1))); + + if let JsonRpcMessage::Response(r) = &response { + assert!(r.result.get("capabilities").is_some()); + assert!(r.result.get("serverInfo").is_some()); + let server_info = r.result.get("serverInfo").unwrap(); + assert_eq!( + server_info.get("name").unwrap().as_str().unwrap(), + "Emulated-Stateless-Server" + ); + } + } + + #[test] + fn test_stateless_mode_initialize_request_detection() { + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + }); + assert_eq!(init_req.method(), Some("initialize")); + + let init_notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + assert_eq!(init_notif.method(), Some("notifications/initialized")); + } + + #[test] + fn test_gift_wrap_kind_detection() { + assert!(is_gift_wrap_kind(&Kind::Custom(GIFT_WRAP_KIND))); + assert!(is_gift_wrap_kind(&Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND))); + assert!(!is_gift_wrap_kind(&Kind::Custom(CTXVM_MESSAGES_KIND))); + } + + #[test] + fn test_required_mode_drops_plaintext() { + let plaintext_kind = Kind::Custom(CTXVM_MESSAGES_KIND); + assert!( + violates_encryption_policy(&plaintext_kind, &EncryptionMode::Required), + "Required mode must reject plaintext (non-gift-wrap) events" + ); + } + + #[test] + fn test_disabled_mode_drops_encrypted() { + assert!( + violates_encryption_policy(&Kind::Custom(GIFT_WRAP_KIND), &EncryptionMode::Disabled), + "Disabled mode must reject gift-wrap events" + ); + assert!( + violates_encryption_policy( + &Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND), + &EncryptionMode::Disabled + ), + "Disabled mode must reject ephemeral gift-wrap events" + ); + } + + #[test] + fn test_optional_mode_accepts_all() { + let plaintext = Kind::Custom(CTXVM_MESSAGES_KIND); + let gift_wrap = Kind::Custom(GIFT_WRAP_KIND); + let ephemeral = Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND); + assert!(!violates_encryption_policy( + &plaintext, + &EncryptionMode::Optional + )); + assert!(!violates_encryption_policy( + &gift_wrap, + &EncryptionMode::Optional + )); + assert!(!violates_encryption_policy( + &ephemeral, + &EncryptionMode::Optional + )); + } + + #[test] + fn test_required_mode_accepts_encrypted() { + assert!( + !violates_encryption_policy(&Kind::Custom(GIFT_WRAP_KIND), &EncryptionMode::Required), + "Required mode must accept gift-wrap events" + ); + assert!( + !violates_encryption_policy( + &Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND), + &EncryptionMode::Required + ), + "Required mode must accept ephemeral gift-wrap events" + ); + } + + #[test] + fn test_disabled_mode_accepts_plaintext() { + let plaintext = Kind::Custom(CTXVM_MESSAGES_KIND); + assert!( + !violates_encryption_policy(&plaintext, &EncryptionMode::Disabled), + "Disabled mode must accept plaintext events" + ); + } + + // ── CEP-35 client discovery tag emission ──────────────────── + + fn make_transport_for_tags( + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + ) -> NostrClientTransport { + let keys = Keys::generate(); + NostrClientTransport { + base: BaseTransport { + relay_pool: Arc::new(crate::relay::mock::MockRelayPool::new()), + encryption_mode, + is_connected: false, + }, + config: NostrClientTransportConfig { + encryption_mode, + gift_wrap_mode, + server_pubkey: Keys::generate().public_key().to_hex(), + ..Default::default() + }, + server_pubkey: keys.public_key(), + pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), + seen_gift_wrap_ids: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap()))), + message_tx: Some(tokio::sync::mpsc::unbounded_channel().0), + message_rx: None, + cancellation_token: CancellationToken::new(), + event_loop_handle: None, + } + } + + fn make_tag(parts: &[&str]) -> Tag { + let kind = TagKind::Custom(parts[0].into()); + let values: Vec = parts[1..].iter().map(|s| s.to_string()).collect(); + Tag::custom(kind, values) + } + + fn tag_names(tags: &[Tag]) -> Vec { + tags.iter().map(|t| t.clone().to_vec()[0].clone()).collect() + } + + #[test] + fn client_capability_tags_encryption_optional() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Optional); + let tags = t.get_client_capability_tags(); + let names = tag_names(&tags); + assert_eq!( + names, + vec!["support_encryption", "support_encryption_ephemeral"] + ); + } + + #[test] + fn client_capability_tags_encryption_disabled() { + let t = make_transport_for_tags(EncryptionMode::Disabled, GiftWrapMode::Optional); + let tags = t.get_client_capability_tags(); + assert!(tags.is_empty()); + } + + #[test] + fn client_capability_tags_persistent_gift_wrap() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Persistent); + let tags = t.get_client_capability_tags(); + let names = tag_names(&tags); + assert_eq!(names, vec!["support_encryption"]); + } + + #[test] + fn client_discovery_tags_sent_once() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Optional); + let first = t.get_pending_client_discovery_tags(); + assert!(!first.is_empty()); + + t.has_sent_discovery_tags.store(true, Ordering::Relaxed); + let second = t.get_pending_client_discovery_tags(); + assert!(second.is_empty()); + } + + // ── CEP-35 client capability learning ─────────────────────── + + fn make_event_with_tags(tag_parts: &[&[&str]]) -> Event { + let keys = Keys::generate(); + let tags: Vec = tag_parts.iter().map(|p| make_tag(p)).collect(); + let builder = EventBuilder::new(Kind::Custom(CTXVM_MESSAGES_KIND), "{}").tags(tags); + let unsigned = builder.build(keys.public_key()); + unsigned.sign_with_keys(&keys).unwrap() + } + + #[test] + fn client_learn_server_discovery_sets_baseline() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + let event = make_event_with_tags(&[&["support_encryption"], &["name", "TestServer"]]); + + NostrClientTransport::learn_server_discovery(&caps, &init, &event); + + let c = caps.lock().unwrap(); + assert!(c.supports_encryption); + assert!(!c.supports_ephemeral_encryption); + + let stored = init.lock().unwrap(); + assert!(stored.is_some()); + assert_eq!(stored.as_ref().unwrap().id, event.id); + } + + #[test] + fn client_learn_server_discovery_or_assigns() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + + let event1 = make_event_with_tags(&[&["support_encryption"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event1); + + // Second event with different caps does NOT downgrade + let event2 = make_event_with_tags(&[&["support_encryption_ephemeral"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event2); + + let c = caps.lock().unwrap(); + assert!(c.supports_encryption, "must not downgrade"); + assert!(c.supports_ephemeral_encryption, "must learn new cap"); + } + + #[test] + fn client_baseline_not_replaced_on_later_events() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + + let event1 = make_event_with_tags(&[&["support_encryption"], &["name", "First"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event1); + let first_id = event1.id; + + let event2 = + make_event_with_tags(&[&["support_encryption_ephemeral"], &["name", "Second"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event2); + + let stored = init.lock().unwrap(); + assert_eq!( + stored.as_ref().unwrap().id, + first_id, + "baseline must not be replaced" + ); + } +} diff --git a/src/transport/discovery_tags.rs b/src/transport/discovery_tags.rs new file mode 100644 index 0000000..ca20a9c --- /dev/null +++ b/src/transport/discovery_tags.rs @@ -0,0 +1,273 @@ +//! Discovery tag utilities for CEP-35 capability exchange. +//! +//! Ports the TS SDK's `discovery-tags.ts` module. Provides functions to filter, +//! parse, and learn discovery tags on Nostr events exchanged between MCP clients +//! and servers. + +use nostr_sdk::prelude::*; + +use crate::core::constants::tags; + +/// Routing tag names that are excluded from discovery tags. +const NON_DISCOVERY_TAG_NAMES: &[&str] = &["p", "e"]; + +/// Capability flags learned from inbound peer discovery tags. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct PeerCapabilities { + /// Peer supports NIP-44/NIP-59 encrypted messaging. + pub supports_encryption: bool, + /// Peer supports ephemeral gift wraps (kind 21059, CEP-19). + pub supports_ephemeral_encryption: bool, + /// Peer supports CEP-22 oversized payload transfer. + pub supports_oversized_transfer: bool, +} + +/// Returns `true` when the tag list contains a single-valued tag whose name matches `name`. +/// +/// A single-valued tag is a tag array whose only element is the tag name itself, +/// e.g. `["support_encryption"]`. +pub fn has_single_tag(tags: &[Tag], name: &str) -> bool { + tags.iter().any(|tag| { + let v = tag.clone().to_vec(); + v.len() == 1 && v[0] == name + }) +} + +/// Filters out routing tags (`p`, `e`) and returns cloned discovery tags. +/// +/// Mirrors TS SDK `getDiscoveryTags()`. +pub fn get_discovery_tags(tags: &[Tag]) -> Vec { + tags.iter() + .filter(|tag| { + let v = (*tag).clone().to_vec(); + match v.first() { + Some(name) => !NON_DISCOVERY_TAG_NAMES.contains(&name.as_str()), + None => false, + } + }) + .cloned() + .collect() +} + +/// Inspects tags and returns discovered peer capabilities. +/// +/// Mirrors TS SDK `learnPeerCapabilities()`. +pub fn learn_peer_capabilities(tags: &[Tag]) -> PeerCapabilities { + PeerCapabilities { + supports_encryption: has_single_tag(tags, tags::SUPPORT_ENCRYPTION), + supports_ephemeral_encryption: has_single_tag(tags, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + supports_oversized_transfer: has_single_tag(tags, tags::SUPPORT_OVERSIZED_TRANSFER), + } +} + +/// Parsed capability flags together with the raw discovery tags. +#[derive(Debug, Clone)] +pub struct DiscoveredPeerCapabilities { + /// The filtered discovery tags (routing tags stripped). + pub discovery_tags: Vec, + /// Parsed capability flags. + pub capabilities: PeerCapabilities, +} + +/// Parses peer discovery tags into normalized capability flags plus the raw +/// discovery tags for storage/forwarding. +/// +/// Mirrors TS SDK `parseDiscoveredPeerCapabilities()`. +pub fn parse_discovered_peer_capabilities(tags: &[Tag]) -> DiscoveredPeerCapabilities { + let discovery_tags = get_discovery_tags(tags); + let capabilities = learn_peer_capabilities(&discovery_tags); + DiscoveredPeerCapabilities { + discovery_tags, + capabilities, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_tag(parts: &[&str]) -> Tag { + let kind = TagKind::Custom(parts[0].into()); + let values: Vec = parts[1..].iter().map(|s| s.to_string()).collect(); + Tag::custom(kind, values) + } + + fn tag_name(tag: &Tag) -> String { + tag.clone().to_vec()[0].clone() + } + + // ── has_single_tag ────────────────────────────────────────────── + + #[test] + fn has_single_tag_finds_present() { + let tags = vec![make_tag(&["support_encryption"])]; + assert!(has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_ignores_multi_value() { + let tags = vec![make_tag(&["support_encryption", "extra"])]; + assert!(!has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_returns_false_when_absent() { + let tags = vec![make_tag(&["other_tag"])]; + assert!(!has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_empty_tags() { + assert!(!has_single_tag(&[], "support_encryption")); + } + + // ── get_discovery_tags ────────────────────────────────────────── + + #[test] + fn get_discovery_tags_filters_routing_tags() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + make_tag(&["support_encryption"]), + make_tag(&["name", "My Server"]), + ]; + let discovery = get_discovery_tags(&tags); + assert_eq!(discovery.len(), 2); + assert_eq!(tag_name(&discovery[0]), "support_encryption"); + assert_eq!(tag_name(&discovery[1]), "name"); + } + + #[test] + fn get_discovery_tags_empty_input() { + let discovery = get_discovery_tags(&[]); + assert!(discovery.is_empty()); + } + + #[test] + fn get_discovery_tags_all_routing() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + ]; + let discovery = get_discovery_tags(&tags); + assert!(discovery.is_empty()); + } + + #[test] + fn get_discovery_tags_preserves_order() { + let tags = vec![ + make_tag(&["about", "hello"]), + Tag::public_key(Keys::generate().public_key()), + make_tag(&["website", "https://example.com"]), + make_tag(&["support_encryption"]), + ]; + let discovery = get_discovery_tags(&tags); + assert_eq!(discovery.len(), 3); + assert_eq!(tag_name(&discovery[0]), "about"); + assert_eq!(tag_name(&discovery[1]), "website"); + assert_eq!(tag_name(&discovery[2]), "support_encryption"); + } + + // ── learn_peer_capabilities ───────────────────────────────────── + + #[test] + fn learn_peer_capabilities_all_present() { + let tags = vec![ + make_tag(&["support_encryption"]), + make_tag(&["support_encryption_ephemeral"]), + make_tag(&["support_oversized_transfer"]), + ]; + let caps = learn_peer_capabilities(&tags); + assert!(caps.supports_encryption); + assert!(caps.supports_ephemeral_encryption); + assert!(caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_none_present() { + let tags = vec![make_tag(&["name", "Server"])]; + let caps = learn_peer_capabilities(&tags); + assert!(!caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_partial() { + let tags = vec![make_tag(&["support_encryption"])]; + let caps = learn_peer_capabilities(&tags); + assert!(caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_empty() { + let caps = learn_peer_capabilities(&[]); + assert_eq!(caps, PeerCapabilities::default()); + } + + #[test] + fn learn_peer_capabilities_ignores_multi_value_capability_tags() { + // Tags with values (e.g. ["support_encryption", "extra"]) are not + // single-valued and should not be treated as capability flags. + let tags = vec![ + make_tag(&["support_encryption", "yes"]), + make_tag(&["support_encryption_ephemeral"]), + ]; + let caps = learn_peer_capabilities(&tags); + assert!(!caps.supports_encryption); + assert!(caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + // ── parse_discovered_peer_capabilities ────────────────────────── + + #[test] + fn parse_discovered_peer_capabilities_filters_and_parses() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + make_tag(&["support_encryption"]), + make_tag(&["support_encryption_ephemeral"]), + make_tag(&["name", "Test Server"]), + ]; + let result = parse_discovered_peer_capabilities(&tags); + + // Routing tags filtered out + assert_eq!(result.discovery_tags.len(), 3); + + // Capabilities parsed correctly + assert!(result.capabilities.supports_encryption); + assert!(result.capabilities.supports_ephemeral_encryption); + assert!(!result.capabilities.supports_oversized_transfer); + } + + #[test] + fn parse_discovered_peer_capabilities_empty() { + let result = parse_discovered_peer_capabilities(&[]); + assert!(result.discovery_tags.is_empty()); + assert_eq!(result.capabilities, PeerCapabilities::default()); + } + + // ── PeerCapabilities ──────────────────────────────────────────── + + #[test] + fn peer_capabilities_default_all_false() { + let caps = PeerCapabilities::default(); + assert!(!caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn peer_capabilities_copy_semantics() { + let caps = PeerCapabilities { + supports_encryption: true, + supports_ephemeral_encryption: true, + supports_oversized_transfer: false, + }; + let copy = caps; + assert_eq!(caps, copy); + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 13a4f8a..a5d53e8 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -5,7 +5,9 @@ pub mod base; pub mod client; +pub mod discovery_tags; pub mod server; -pub use client::{NostrClientTransport, NostrClientTransportConfig}; -pub use server::{NostrServerTransport, NostrServerTransportConfig}; +pub use client::{ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig}; +pub use discovery_tags::*; +pub use server::{NostrServerTransport, NostrServerTransportConfig, ServerEventRouteStore}; diff --git a/src/transport/server.rs b/src/transport/server.rs deleted file mode 100644 index 43d97a8..0000000 --- a/src/transport/server.rs +++ /dev/null @@ -1,755 +0,0 @@ -//! Server-side Nostr transport for ContextVM. -//! -//! Listens for incoming MCP requests from clients over Nostr, manages multi-client -//! sessions, handles request/response correlation, and optionally publishes -//! server announcements. - -use std::collections::HashMap; -use std::sync::Arc; -use std::time::{Duration, Instant}; - -use nostr_sdk::prelude::*; -use tokio::sync::RwLock; - -use crate::core::constants::*; -use crate::core::error::{Error, Result}; -use crate::core::serializers; -use crate::core::types::*; -use crate::encryption; -use crate::relay::RelayPool; -use crate::transport::base::BaseTransport; - -/// Configuration for the server transport. -pub struct NostrServerTransportConfig { - /// Relay URLs to connect to. - pub relay_urls: Vec, - /// Encryption mode. - pub encryption_mode: EncryptionMode, - /// Server information for announcements. - pub server_info: Option, - /// Whether this is a public server (publishes announcements). - pub is_public_server: bool, - /// Allowed client public keys (hex). Empty = allow all. - pub allowed_public_keys: Vec, - /// Capabilities excluded from pubkey whitelisting. - pub excluded_capabilities: Vec, - /// Session cleanup interval (default: 60s). - pub cleanup_interval: Duration, - /// Session timeout (default: 300s). - pub session_timeout: Duration, -} - -impl Default for NostrServerTransportConfig { - fn default() -> Self { - Self { - relay_urls: vec!["wss://relay.damus.io".to_string()], - encryption_mode: EncryptionMode::Optional, - server_info: None, - is_public_server: false, - allowed_public_keys: Vec::new(), - excluded_capabilities: Vec::new(), - cleanup_interval: Duration::from_secs(60), - session_timeout: Duration::from_secs(300), - } - } -} - -/// Server-side Nostr transport — receives MCP requests and sends responses. -pub struct NostrServerTransport { - base: BaseTransport, - config: NostrServerTransportConfig, - /// Client sessions: client_pubkey_hex → ClientSession - sessions: Arc>>, - /// Reverse lookup: event_id → client_pubkey_hex - event_to_client: Arc>>, - /// Channel for incoming MCP messages (consumed by the MCP server). - message_tx: tokio::sync::mpsc::UnboundedSender, - message_rx: Option>, -} - -/// An incoming MCP request with metadata for routing the response. -#[derive(Debug)] -pub struct IncomingRequest { - /// The parsed MCP message. - pub message: JsonRpcMessage, - /// The client's public key (hex). - pub client_pubkey: String, - /// The Nostr event ID (for response correlation). - pub event_id: String, - /// Whether the original message was encrypted. - pub is_encrypted: bool, -} - -impl NostrServerTransport { - /// Create a new server transport. - pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result - where - T: IntoNostrSigner, - { - let relay_pool = Arc::new(RelayPool::new(signer).await?); - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - Ok(Self { - base: BaseTransport { - relay_pool, - encryption_mode: config.encryption_mode, - is_connected: false, - }, - config, - sessions: Arc::new(RwLock::new(HashMap::new())), - event_to_client: Arc::new(RwLock::new(HashMap::new())), - message_tx: tx, - message_rx: Some(rx), - }) - } - - /// Start listening for incoming requests. - pub async fn start(&mut self) -> Result<()> { - self.base.connect(&self.config.relay_urls).await?; - - let pubkey = self.base.get_public_key().await?; - tracing::info!(pubkey = %pubkey.to_hex(), "Server transport started"); - - self.base.subscribe_for_pubkey(&pubkey).await?; - - // Spawn event loop - let client = self.base.relay_pool.client().clone(); - let sessions = self.sessions.clone(); - let event_to_client = self.event_to_client.clone(); - let tx = self.message_tx.clone(); - let allowed = self.config.allowed_public_keys.clone(); - let excluded = self.config.excluded_capabilities.clone(); - let encryption_mode = self.config.encryption_mode; - - tokio::spawn(async move { - Self::event_loop(client, sessions, event_to_client, tx, allowed, excluded, encryption_mode).await; - }); - - // Spawn session cleanup - let sessions_cleanup = self.sessions.clone(); - let event_to_client_cleanup = self.event_to_client.clone(); - let cleanup_interval = self.config.cleanup_interval; - let session_timeout = self.config.session_timeout; - - tokio::spawn(async move { - let mut interval = tokio::time::interval(cleanup_interval); - loop { - interval.tick().await; - let cleaned = Self::cleanup_sessions( - &sessions_cleanup, - &event_to_client_cleanup, - session_timeout, - ) - .await; - if cleaned > 0 { - tracing::info!(cleaned, "Cleaned up inactive sessions"); - } - } - }); - - Ok(()) - } - - /// Close the transport. - pub async fn close(&mut self) -> Result<()> { - self.base.disconnect().await?; - self.sessions.write().await.clear(); - self.event_to_client.write().await.clear(); - Ok(()) - } - - /// Send a response back to the client that sent the original request. - pub async fn send_response( - &self, - event_id: &str, - mut response: JsonRpcMessage, - ) -> Result<()> { - let event_to_client = self.event_to_client.read().await; - let client_pubkey_hex = event_to_client - .get(event_id) - .ok_or_else(|| Error::Other(format!("No client found for event {event_id}")))? - .clone(); - drop(event_to_client); - - let sessions = self.sessions.read().await; - let session = sessions - .get(&client_pubkey_hex) - .ok_or_else(|| Error::Other(format!("No session for client {client_pubkey_hex}")))?; - - // Restore original request ID - if let Some(original_id) = session.pending_requests.get(event_id) { - match &mut response { - JsonRpcMessage::Response(r) => r.id = original_id.clone(), - JsonRpcMessage::ErrorResponse(r) => r.id = original_id.clone(), - _ => {} - } - } - - let is_encrypted = session.is_encrypted; - drop(sessions); - - let client_pubkey = PublicKey::from_hex(&client_pubkey_hex) - .map_err(|e| Error::Other(e.to_string()))?; - - let event_id_parsed = - EventId::from_hex(event_id).map_err(|e| Error::Other(e.to_string()))?; - - let tags = BaseTransport::create_response_tags(&client_pubkey, &event_id_parsed); - - self.base - .send_mcp_message( - &response, - &client_pubkey, - CTXVM_MESSAGES_KIND, - tags, - Some(is_encrypted), - ) - .await?; - - // Clean up - let mut sessions = self.sessions.write().await; - if let Some(session) = sessions.get_mut(&client_pubkey_hex) { - // Clean up progress token - if let Some(token) = session.event_to_progress_token.remove(event_id) { - session.pending_requests.remove(&token); - } - session.pending_requests.remove(event_id); - } - drop(sessions); - - self.event_to_client.write().await.remove(event_id); - - Ok(()) - } - - /// Send a notification to a specific client. - pub async fn send_notification( - &self, - client_pubkey_hex: &str, - notification: &JsonRpcMessage, - correlated_event_id: Option<&str>, - ) -> Result<()> { - let sessions = self.sessions.read().await; - let session = sessions - .get(client_pubkey_hex) - .ok_or_else(|| Error::Other(format!("No session for {client_pubkey_hex}")))?; - let is_encrypted = session.is_encrypted; - drop(sessions); - - let client_pubkey = PublicKey::from_hex(client_pubkey_hex) - .map_err(|e| Error::Other(e.to_string()))?; - - let mut tags = BaseTransport::create_recipient_tags(&client_pubkey); - if let Some(eid) = correlated_event_id { - let event_id = EventId::from_hex(eid).map_err(|e| Error::Other(e.to_string()))?; - tags.push(Tag::event(event_id)); - } - - self.base - .send_mcp_message( - notification, - &client_pubkey, - CTXVM_MESSAGES_KIND, - tags, - Some(is_encrypted), - ) - .await?; - - Ok(()) - } - - /// Broadcast a notification to all initialized clients. - pub async fn broadcast_notification(&self, notification: &JsonRpcMessage) -> Result<()> { - let sessions = self.sessions.read().await; - let initialized: Vec = sessions - .iter() - .filter(|(_, s)| s.is_initialized) - .map(|(k, _)| k.clone()) - .collect(); - drop(sessions); - - for pubkey in initialized { - if let Err(e) = self.send_notification(&pubkey, notification, None).await { - tracing::error!(client = %pubkey, "Failed to send notification: {e}"); - } - } - Ok(()) - } - - /// Take the message receiver for consuming incoming requests. - pub fn take_message_receiver( - &mut self, - ) -> Option> { - self.message_rx.take() - } - - /// Publish server announcement (kind 11316). - pub async fn announce(&self) -> Result { - let info = self - .config - .server_info - .as_ref() - .ok_or_else(|| Error::Other("No server info configured".to_string()))?; - - let content = serde_json::to_string(info)?; - - let mut tags = Vec::new(); - if let Some(ref name) = info.name { - tags.push(Tag::custom( - TagKind::Custom(tags::NAME.into()), - vec![name.clone()], - )); - } - if let Some(ref about) = info.about { - tags.push(Tag::custom( - TagKind::Custom(tags::ABOUT.into()), - vec![about.clone()], - )); - } - if let Some(ref website) = info.website { - tags.push(Tag::custom( - TagKind::Custom(tags::WEBSITE.into()), - vec![website.clone()], - )); - } - if let Some(ref picture) = info.picture { - tags.push(Tag::custom( - TagKind::Custom(tags::PICTURE.into()), - vec![picture.clone()], - )); - } - if self.config.encryption_mode != EncryptionMode::Disabled { - tags.push(Tag::custom( - TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), - Vec::::new(), - )); - } - - let builder = - EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content).tags(tags); - - self.base.relay_pool.publish(builder).await - } - - /// Publish tools list (kind 11317). - pub async fn publish_tools(&self, tools: Vec) -> Result { - let content = serde_json::json!({ "tools": tools }); - let builder = EventBuilder::new( - Kind::Custom(TOOLS_LIST_KIND), - serde_json::to_string(&content)?, - ); - self.base.relay_pool.publish(builder).await - } - - /// Publish resources list (kind 11318). - pub async fn publish_resources(&self, resources: Vec) -> Result { - let content = serde_json::json!({ "resources": resources }); - let builder = EventBuilder::new( - Kind::Custom(RESOURCES_LIST_KIND), - serde_json::to_string(&content)?, - ); - self.base.relay_pool.publish(builder).await - } - - /// Publish prompts list (kind 11320). - pub async fn publish_prompts(&self, prompts: Vec) -> Result { - let content = serde_json::json!({ "prompts": prompts }); - let builder = EventBuilder::new( - Kind::Custom(PROMPTS_LIST_KIND), - serde_json::to_string(&content)?, - ); - self.base.relay_pool.publish(builder).await - } - - /// Publish resource templates list (kind 11319). - pub async fn publish_resource_templates( - &self, - templates: Vec, - ) -> Result { - let content = serde_json::json!({ "resourceTemplates": templates }); - let builder = EventBuilder::new( - Kind::Custom(RESOURCETEMPLATES_LIST_KIND), - serde_json::to_string(&content)?, - ); - self.base.relay_pool.publish(builder).await - } - - /// Delete server announcements (NIP-09 kind 5). - pub async fn delete_announcements(&self, reason: &str) -> Result<()> { - // We publish kind 5 events for each announcement kind - let pubkey = self.base.get_public_key().await?; - let _pubkey_hex = pubkey.to_hex(); - - for kind in UNENCRYPTED_KINDS { - let builder = EventBuilder::new(Kind::Custom(5), reason) - .tag(Tag::custom( - TagKind::Custom("k".into()), - vec![kind.to_string()], - )); - self.base.relay_pool.publish(builder).await?; - } - Ok(()) - } - - // ── Internal ──────────────────────────────────────────────── - - fn is_capability_excluded( - excluded: &[CapabilityExclusion], - method: &str, - name: Option<&str>, - ) -> bool { - // Always allow fundamental MCP methods - if method == "initialize" || method == "notifications/initialized" { - return true; - } - - excluded.iter().any(|excl| { - if excl.method != method { - return false; - } - match (&excl.name, name) { - (Some(excl_name), Some(req_name)) => excl_name == req_name, - (None, _) => true, // method-only match - _ => false, - } - }) - } - - async fn event_loop( - client: Arc, - sessions: Arc>>, - event_to_client: Arc>>, - tx: tokio::sync::mpsc::UnboundedSender, - allowed_pubkeys: Vec, - excluded_capabilities: Vec, - encryption_mode: EncryptionMode, - ) { - let mut notifications = client.notifications(); - - while let Ok(notification) = notifications.recv().await { - if let RelayPoolNotification::Event { event, .. } = notification { - let (content, sender_pubkey, event_id, is_encrypted) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) { - if encryption_mode == EncryptionMode::Disabled { - tracing::warn!("Received encrypted message but encryption is disabled"); - continue; - } - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { - Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); - continue; - } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - // The decrypted content is JSON of the inner signed event. - // Use the INNER event's ID for correlation — the client - // registers the inner event ID in its correlation store. - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => ( - inner.content, - inner.pubkey.to_hex(), - inner.id.to_hex(), - true, - ), - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); - continue; - } - } - } - Err(e) => { - tracing::error!("Failed to decrypt: {e}"); - continue; - } - } - } else { - if encryption_mode == EncryptionMode::Required { - tracing::warn!( - pubkey = %event.pubkey, - "Received unencrypted message but encryption is required" - ); - continue; - } - ( - event.content.clone(), - event.pubkey.to_hex(), - event.id.to_hex(), - false, - ) - }; - - // Parse MCP message - let mcp_msg = match serializers::nostr_event_to_mcp_message(&content) { - Some(msg) => msg, - None => { - tracing::warn!("Invalid MCP message from {sender_pubkey}"); - continue; - } - }; - - // Authorization check - if !allowed_pubkeys.is_empty() { - let method = mcp_msg.method().unwrap_or(""); - let name = match &mcp_msg { - JsonRpcMessage::Request(r) => r - .params - .as_ref() - .and_then(|p| p.get("name")) - .and_then(|n| n.as_str()), - _ => None, - }; - - let is_excluded = - Self::is_capability_excluded(&excluded_capabilities, method, name); - - if !allowed_pubkeys.contains(&sender_pubkey) && !is_excluded { - tracing::warn!( - pubkey = %sender_pubkey, - method = %method, - "Unauthorized request" - ); - continue; - } - } - - // Session management - let mut sessions_w = sessions.write().await; - let session = sessions_w - .entry(sender_pubkey.clone()) - .or_insert_with(|| ClientSession::new(is_encrypted)); - session.update_activity(); - session.is_encrypted = is_encrypted; - - // Track request for correlation - if let JsonRpcMessage::Request(ref req) = mcp_msg { - let original_id = req.id.clone(); - session - .pending_requests - .insert(event_id.clone(), original_id); - event_to_client - .write() - .await - .insert(event_id.clone(), sender_pubkey.clone()); - - // Track progress token - if let Some(token) = req - .params - .as_ref() - .and_then(|p| p.get("_meta")) - .and_then(|m| m.get("progressToken")) - .and_then(|t| t.as_str()) - { - session - .pending_requests - .insert(token.to_string(), serde_json::json!(event_id)); - session - .event_to_progress_token - .insert(event_id.clone(), token.to_string()); - } - } - - // Handle initialized notification - if let JsonRpcMessage::Notification(ref n) = mcp_msg { - if n.method == "notifications/initialized" { - session.is_initialized = true; - } - } - - drop(sessions_w); - - // Forward to consumer - let _ = tx.send(IncomingRequest { - message: mcp_msg, - client_pubkey: sender_pubkey, - event_id, - is_encrypted, - }); - } - } - } - - async fn cleanup_sessions( - sessions: &RwLock>, - event_to_client: &RwLock>, - timeout: Duration, - ) -> usize { - let mut sessions_w = sessions.write().await; - let mut event_map = event_to_client.write().await; - let mut cleaned = 0; - - sessions_w.retain(|pubkey, session| { - if session.last_activity.elapsed() > timeout { - // Clean up reverse mappings - for event_id in session.pending_requests.keys() { - event_map.remove(event_id); - } - for event_id in session.event_to_progress_token.keys() { - event_map.remove(event_id); - } - tracing::debug!(client = %pubkey, "Session expired"); - cleaned += 1; - false - } else { - true - } - }); - - cleaned - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::thread; - - // ── Session management ────────────────────────────────────── - - #[test] - fn test_client_session_creation() { - let session = ClientSession::new(true); - assert!(!session.is_initialized); - assert!(session.is_encrypted); - assert!(session.pending_requests.is_empty()); - assert!(session.event_to_progress_token.is_empty()); - } - - #[test] - fn test_client_session_update_activity() { - let mut session = ClientSession::new(false); - let first = session.last_activity; - thread::sleep(Duration::from_millis(10)); - session.update_activity(); - assert!(session.last_activity > first); - } - - #[tokio::test] - async fn test_cleanup_sessions_removes_expired() { - let sessions = Arc::new(RwLock::new(HashMap::new())); - let event_to_client = Arc::new(RwLock::new(HashMap::new())); - - // Insert a session with an old activity time - let mut session = ClientSession::new(false); - session.pending_requests.insert("evt1".to_string(), serde_json::json!(1)); - sessions.write().await.insert("pubkey1".to_string(), session); - event_to_client.write().await.insert("evt1".to_string(), "pubkey1".to_string()); - - // With a long timeout, nothing should be cleaned - let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_secs(300), - ).await; - assert_eq!(cleaned, 0); - assert_eq!(sessions.read().await.len(), 1); - - // With zero timeout, it should be cleaned - thread::sleep(Duration::from_millis(5)); - let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_millis(1), - ).await; - assert_eq!(cleaned, 1); - assert!(sessions.read().await.is_empty()); - assert!(event_to_client.read().await.is_empty()); - } - - #[tokio::test] - async fn test_cleanup_preserves_active_sessions() { - let sessions = Arc::new(RwLock::new(HashMap::new())); - let event_to_client = Arc::new(RwLock::new(HashMap::new())); - - let session = ClientSession::new(false); - sessions.write().await.insert("active".to_string(), session); - - let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_secs(300), - ).await; - assert_eq!(cleaned, 0); - assert_eq!(sessions.read().await.len(), 1); - } - - // ── Request ID correlation ────────────────────────────────── - - #[test] - fn test_pending_request_tracking() { - let mut session = ClientSession::new(false); - session.pending_requests.insert("event_abc".to_string(), serde_json::json!(42)); - assert_eq!(session.pending_requests.get("event_abc"), Some(&serde_json::json!(42))); - } - - #[test] - fn test_progress_token_tracking() { - let mut session = ClientSession::new(false); - session.event_to_progress_token.insert("evt1".to_string(), "token1".to_string()); - session.pending_requests.insert("token1".to_string(), serde_json::json!("evt1")); - assert_eq!(session.event_to_progress_token.get("evt1"), Some(&"token1".to_string())); - } - - // ── Authorization (is_capability_excluded) ────────────────── - - #[test] - fn test_initialize_always_excluded() { - assert!(NostrServerTransport::is_capability_excluded(&[], "initialize", None)); - assert!(NostrServerTransport::is_capability_excluded(&[], "notifications/initialized", None)); - } - - #[test] - fn test_method_excluded_without_name() { - let exclusions = vec![CapabilityExclusion { - method: "tools/list".to_string(), - name: None, - }]; - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/list", None)); - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/list", Some("anything"))); - } - - #[test] - fn test_method_excluded_with_name() { - let exclusions = vec![CapabilityExclusion { - method: "tools/call".to_string(), - name: Some("get_weather".to_string()), - }]; - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", Some("get_weather"))); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", Some("other_tool"))); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", None)); - } - - #[test] - fn test_non_excluded_method() { - let exclusions = vec![CapabilityExclusion { - method: "tools/list".to_string(), - name: None, - }]; - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", None)); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "resources/list", None)); - } - - #[test] - fn test_empty_exclusions_non_init_method() { - assert!(!NostrServerTransport::is_capability_excluded(&[], "tools/list", None)); - assert!(!NostrServerTransport::is_capability_excluded(&[], "tools/call", Some("x"))); - } - - // ── Encryption mode enforcement ───────────────────────────── - - #[test] - fn test_encryption_mode_default() { - let config = NostrServerTransportConfig::default(); - assert_eq!(config.encryption_mode, EncryptionMode::Optional); - } - - // ── Config defaults ───────────────────────────────────────── - - #[test] - fn test_config_defaults() { - let config = NostrServerTransportConfig::default(); - assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); - assert!(!config.is_public_server); - assert!(config.allowed_public_keys.is_empty()); - assert!(config.excluded_capabilities.is_empty()); - assert_eq!(config.cleanup_interval, Duration::from_secs(60)); - assert_eq!(config.session_timeout, Duration::from_secs(300)); - assert!(config.server_info.is_none()); - } -} diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs new file mode 100644 index 0000000..c25404e --- /dev/null +++ b/src/transport/server/correlation_store.rs @@ -0,0 +1,387 @@ +//! Server-side event route store for mapping event IDs to client routes. + +use std::collections::{HashMap, HashSet}; +use std::num::NonZeroUsize; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use lru::LruCache; +use tokio::sync::RwLock; + +use crate::core::constants::DEFAULT_LRU_SIZE; + +/// A route entry for an in-flight request. +#[derive(Debug, Clone)] +pub struct RouteEntry { + /// The client's public key that originated this request. + pub client_pubkey: String, + /// The original JSON-RPC request ID (before replacement with event ID). + pub original_request_id: serde_json::Value, + /// Optional progress token for this request. + pub progress_token: Option, + /// The outer gift-wrap event kind that carried this request (e.g. 1059 or 21059). + /// Populated from the inbound event in a later PR; `None` until then. + pub wrap_kind: Option, + /// When the route was registered. + pub registered_at: Instant, +} + +/// Internal state behind the lock. +struct Inner { + /// Primary index: event_id → route entry (LRU-ordered). + routes: LruCache, + /// Secondary index: progress_token → event_id. + progress_token_to_event: HashMap, + /// Secondary index: client_pubkey → set of event_ids. + client_event_ids: HashMap>, +} + +impl Inner { + fn new(max_routes: usize) -> Self { + let routes = + LruCache::new(NonZeroUsize::new(max_routes).unwrap_or(NonZeroUsize::new(1).unwrap())); + Self { + routes, + progress_token_to_event: HashMap::new(), + client_event_ids: HashMap::new(), + } + } + + /// Clean up secondary indexes for a removed route. + fn cleanup_indexes(&mut self, event_id: &str, route: &RouteEntry) { + if let Some(ref token) = route.progress_token { + self.progress_token_to_event.remove(token); + } + if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) { + set.remove(event_id); + if set.is_empty() { + self.client_event_ids.remove(&route.client_pubkey); + } + } + } + + /// Remove a single route and clean up all secondary indexes. + fn remove_route(&mut self, event_id: &str) -> Option { + let route = self.routes.pop(event_id)?; + self.cleanup_indexes(event_id, &route); + Some(route) + } +} + +/// Maps event IDs to full route entries for response routing on the server side. +/// +/// An optional capacity limit enables LRU eviction; when the limit is reached +/// the oldest entry is evicted and its secondary indexes are cleaned up. +#[derive(Clone)] +pub struct ServerEventRouteStore { + inner: Arc>, +} + +impl Default for ServerEventRouteStore { + fn default() -> Self { + Self::new() + } +} + +impl ServerEventRouteStore { + pub fn new() -> Self { + Self { + inner: Arc::new(RwLock::new(Inner::new(DEFAULT_LRU_SIZE))), + } + } + + /// Create a store with an upper bound on event routes. + /// When the limit is reached the oldest entry is evicted. + pub fn with_max_routes(max_routes: usize) -> Self { + Self { + inner: Arc::new(RwLock::new(Inner::new(max_routes))), + } + } + + /// Register a route for an incoming request. + pub async fn register( + &self, + event_id: String, + client_pubkey: String, + original_request_id: serde_json::Value, + progress_token: Option, + ) { + let mut inner = self.inner.write().await; + + // Update client index. + inner + .client_event_ids + .entry(client_pubkey.clone()) + .or_default() + .insert(event_id.clone()); + + // Update progress token index. + if let Some(ref token) = progress_token { + inner + .progress_token_to_event + .insert(token.clone(), event_id.clone()); + } + + // Insert into LRU; handle possible eviction. + let evicted = inner.routes.push( + event_id.clone(), + RouteEntry { + client_pubkey, + original_request_id, + progress_token, + wrap_kind: None, + registered_at: Instant::now(), + }, + ); + + if let Some((evicted_key, evicted_route)) = evicted { + if evicted_key != event_id { + // A different entry was evicted due to capacity — clean up its indexes. + inner.cleanup_indexes(&evicted_key, &evicted_route); + } + } + } + + /// Returns the client public key for the given event ID without removing it. + pub async fn get(&self, event_id: &str) -> Option { + self.inner + .read() + .await + .routes + .peek(event_id) + .map(|r| r.client_pubkey.clone()) + } + + /// Returns the full route entry for the given event ID without removing it. + pub async fn get_route(&self, event_id: &str) -> Option { + self.inner.read().await.routes.peek(event_id).cloned() + } + + /// Removes and returns the full route entry for the given event ID. + pub async fn pop(&self, event_id: &str) -> Option { + self.inner.write().await.remove_route(event_id) + } + + /// Removes all routes for a given client public key. Returns the count removed. + pub async fn remove_for_client(&self, client_pubkey: &str) -> usize { + let mut inner = self.inner.write().await; + + let event_ids = match inner.client_event_ids.remove(client_pubkey) { + Some(ids) => ids, + None => return 0, + }; + + let count = event_ids.len(); + for event_id in &event_ids { + if let Some(route) = inner.routes.pop(event_id.as_str()) { + if let Some(ref token) = route.progress_token { + inner.progress_token_to_event.remove(token); + } + } + } + count + } + + /// Check whether a route exists for the given event ID. + pub async fn has_event_route(&self, event_id: &str) -> bool { + self.inner.read().await.routes.contains(event_id) + } + + /// Check whether the given client has any active routes. + pub async fn has_active_routes_for_client(&self, client_pubkey: &str) -> bool { + self.inner + .read() + .await + .client_event_ids + .get(client_pubkey) + .is_some_and(|set| !set.is_empty()) + } + + /// Look up the event ID associated with a progress token. + pub async fn get_event_id_by_progress_token(&self, token: &str) -> Option { + self.inner + .read() + .await + .progress_token_to_event + .get(token) + .cloned() + } + + /// Check whether a progress token mapping exists. + pub async fn has_progress_token(&self, token: &str) -> bool { + self.inner + .read() + .await + .progress_token_to_event + .contains_key(token) + } + + /// Number of event routes currently tracked. + pub async fn event_route_count(&self) -> usize { + self.inner.read().await.routes.len() + } + + /// Number of progress token mappings currently tracked. + pub async fn progress_token_count(&self) -> usize { + self.inner.read().await.progress_token_to_event.len() + } + + /// Remove all route entries older than `timeout`. + /// (Routes for expired sessions are already cleaned by `cleanup_sessions`.) + /// Returns the event IDs of the removed entries. + pub async fn sweep_stale_routes(&self, timeout: Duration) -> Vec { + let now = Instant::now(); + let mut inner = self.inner.write().await; + let mut expired_keys = Vec::new(); + + for (key, entry) in inner.routes.iter() { + if now.duration_since(entry.registered_at) >= timeout { + expired_keys.push(key.clone()); + } + } + + for key in &expired_keys { + inner.remove_route(key); + } + expired_keys + } + + pub async fn clear(&self) { + let mut inner = self.inner.write().await; + inner.routes.clear(); + inner.progress_token_to_event.clear(); + inner.client_event_ids.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn pop_on_empty_returns_none() { + let store = ServerEventRouteStore::new(); + assert!(store.pop("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn get_returns_without_removing() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + } + + #[tokio::test] + async fn pop_removes_entry() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + let route = store.pop("e1").await.unwrap(); + assert_eq!(route.client_pubkey, "pk1"); + assert!(store.pop("e1").await.is_none()); + } + + #[tokio::test] + async fn remove_for_client_only_removes_matching() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + store + .register("e2".into(), "pk2".into(), json!("r2"), None) + .await; + store + .register("e3".into(), "pk1".into(), json!("r3"), None) + .await; + + let removed = store.remove_for_client("pk1").await; + assert_eq!(removed, 2); + + assert!(store.get("e1").await.is_none()); + assert!(store.get("e3").await.is_none()); + assert_eq!(store.get("e2").await.as_deref(), Some("pk2")); + } + + #[tokio::test] + async fn remove_for_client_noop_when_no_match() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + let removed = store.remove_for_client("pk_other").await; + assert_eq!(removed, 0); + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + } + + #[tokio::test] + async fn clear_empties_store() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + store + .register("e2".into(), "pk2".into(), json!("r2"), None) + .await; + store.clear().await; + assert!(store.get("e1").await.is_none()); + assert!(store.get("e2").await.is_none()); + } + + #[tokio::test] + async fn default_store_is_bounded() { + let store = ServerEventRouteStore::new(); + for i in 0..=DEFAULT_LRU_SIZE { + store + .register(format!("e{i}"), "pk1".into(), json!(i), None) + .await; + } + + assert_eq!(store.event_route_count().await, DEFAULT_LRU_SIZE); + assert!(!store.has_event_route("e0").await); + assert!(store.has_event_route(&format!("e{DEFAULT_LRU_SIZE}")).await); + } + + #[tokio::test] + async fn sweep_stale_routes_removes_only_expired() { + let store = ServerEventRouteStore::new(); + + // Insert a route that will age past the threshold. + store + .register("old".into(), "pk1".into(), json!(1), Some("tok1".into())) + .await; + + tokio::time::sleep(Duration::from_millis(20)).await; + + // Insert a fresh route. + store + .register("fresh".into(), "pk2".into(), json!(2), None) + .await; + + // Sweep with 10ms timeout — "old" should be removed, "fresh" should remain. + let swept = store.sweep_stale_routes(Duration::from_millis(10)).await; + assert_eq!(swept.len(), 1); + assert_eq!(swept[0], "old"); + assert!(!store.has_event_route("old").await); + assert!(store.has_event_route("fresh").await); + // Secondary indexes should also be cleaned. + assert!(!store.has_progress_token("tok1").await); + assert!(!store.has_active_routes_for_client("pk1").await); + } + + #[tokio::test] + async fn sweep_stale_routes_returns_zero_when_nothing_expired() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!(1), None) + .await; + + let swept = store.sweep_stale_routes(Duration::from_secs(60)).await; + assert!(swept.is_empty()); + assert!(store.has_event_route("e1").await); + } +} diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs new file mode 100644 index 0000000..ba9f6c7 --- /dev/null +++ b/src/transport/server/mod.rs @@ -0,0 +1,1907 @@ +//! Server-side Nostr transport for ContextVM. +//! +//! Listens for incoming MCP requests from clients over Nostr, manages multi-client +//! sessions, handles request/response correlation, and optionally publishes +//! server announcements. + +pub mod correlation_store; +pub mod session_store; + +pub use correlation_store::{RouteEntry, ServerEventRouteStore}; +pub use session_store::{SessionSnapshot, SessionStore}; +use tokio::sync::RwLock; + +use std::collections::HashMap; +use std::num::NonZeroUsize; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use lru::LruCache; +use nostr_sdk::prelude::*; +use tokio_util::sync::CancellationToken; + +use crate::core::constants::*; +use crate::core::error::{Error, Result}; +use crate::core::types::*; +use crate::core::validation; +use crate::encryption; +use crate::relay::{RelayPool, RelayPoolTrait}; +use crate::transport::base::BaseTransport; +use crate::transport::discovery_tags::learn_peer_capabilities; + +const LOG_TARGET: &str = "contextvm_sdk::transport::server"; + +/// Configuration for the server transport. +#[non_exhaustive] +pub struct NostrServerTransportConfig { + /// Relay URLs to connect to. + pub relay_urls: Vec, + /// Encryption mode. + pub encryption_mode: EncryptionMode, + /// Gift-wrap kind selection policy (CEP-19). + pub gift_wrap_mode: GiftWrapMode, + /// Server information for announcements. + pub server_info: Option, + /// Whether this server publishes public announcements (CEP-6). + pub is_announced_server: bool, + /// Allowed client public keys (hex). Empty = allow all. + pub allowed_public_keys: Vec, + /// Capabilities excluded from pubkey whitelisting. + pub excluded_capabilities: Vec, + /// Maximum number of concurrent client sessions (LRU-bounded, default: 1000). + pub max_sessions: usize, + /// Session cleanup interval (default: 60s). + pub cleanup_interval: Duration, + /// Session timeout (default: 300s). + pub session_timeout: Duration, + /// Correlation-retention TTL for server-side event routes (default: 60s). + /// + /// Stale route entries older than this are swept from the correlation store. + /// This prevents leaks -- rmcp owns actual request timeout and cancellation. + /// Keep this value above your rmcp request timeout to avoid premature cleanup. + pub request_timeout: Duration, +} + +impl Default for NostrServerTransportConfig { + fn default() -> Self { + Self { + relay_urls: vec!["wss://relay.damus.io".to_string()], + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + server_info: None, + is_announced_server: false, + allowed_public_keys: Vec::new(), + excluded_capabilities: Vec::new(), + max_sessions: session_store::DEFAULT_MAX_SESSIONS, + cleanup_interval: Duration::from_secs(60), + session_timeout: Duration::from_secs(300), + request_timeout: Duration::from_secs(60), + } + } +} + +/// Server-side Nostr transport — receives MCP requests and sends responses. +pub struct NostrServerTransport { + /// Relay pool for publishing and subscribing. + base: BaseTransport, + /// Configuration for this server transport. + config: NostrServerTransportConfig, + /// Extra common discovery tags to include in server announcements and first responses. + extra_common_tags: Vec, + /// Pricing tags to include in announcements and capability list responses. + pricing_tags: Vec, + /// Client sessions. + sessions: SessionStore, + /// Reverse lookup: event_id → client route. + event_routes: ServerEventRouteStore, + /// CEP-19: Track the incoming gift-wrap kind per request for mirroring. + request_wrap_kinds: Arc>>>, + /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). + /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success + /// so failed decrypt/verify can be retried on redelivery. + seen_gift_wrap_ids: Arc>>, + /// Channel for incoming MCP messages (consumed by the MCP server). + message_tx: Option>, + message_rx: Option>, + /// Token used to cancel spawned tasks (event loop + cleanup) on close(). + cancellation_token: CancellationToken, + /// Handles for spawned tasks (event loop + cleanup). + task_handles: Vec>, +} + +impl NostrServerTransportConfig { + /// Set the encryption mode. + pub fn with_encryption_mode(mut self, mode: EncryptionMode) -> Self { + self.encryption_mode = mode; + self + } + /// Set the gift-wrap mode (CEP-19). + pub fn with_gift_wrap_mode(mut self, mode: GiftWrapMode) -> Self { + self.gift_wrap_mode = mode; + self + } + /// Set server information for announcements. + pub fn with_server_info(mut self, info: ServerInfo) -> Self { + self.server_info = Some(info); + self + } + /// Enable or disable public announcement publishing (CEP-6). + pub fn with_announced_server(mut self, announced: bool) -> Self { + self.is_announced_server = announced; + self + } + /// Set the allowed client public keys (hex). Empty = allow all. + pub fn with_allowed_public_keys(mut self, keys: Vec) -> Self { + self.allowed_public_keys = keys; + self + } + /// Set capabilities excluded from pubkey whitelisting. + pub fn with_excluded_capabilities(mut self, caps: Vec) -> Self { + self.excluded_capabilities = caps; + self + } + /// Set the maximum number of concurrent client sessions. + pub fn with_max_sessions(mut self, max: usize) -> Self { + self.max_sessions = max; + self + } + /// Set the relay URLs to connect to. + pub fn with_relay_urls(mut self, urls: Vec) -> Self { + self.relay_urls = urls; + self + } + /// Set the session cleanup interval. + pub fn with_cleanup_interval(mut self, interval: Duration) -> Self { + self.cleanup_interval = interval; + self + } + /// Set the session timeout. + pub fn with_session_timeout(mut self, timeout: Duration) -> Self { + self.session_timeout = timeout; + self + } + /// Set the correlation-retention TTL for event routes. + pub fn with_request_timeout(mut self, timeout: Duration) -> Self { + self.request_timeout = timeout; + self + } +} + +/// An incoming MCP request with metadata for routing the response. +#[derive(Debug)] +#[non_exhaustive] +pub struct IncomingRequest { + /// The parsed MCP message. + pub message: JsonRpcMessage, + /// The client's public key (hex). + pub client_pubkey: String, + /// The Nostr event ID (for response correlation). + pub event_id: String, + /// Whether the original message was encrypted. + pub is_encrypted: bool, +} + +impl NostrServerTransport { + /// Create a new server transport. + pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result + where + T: IntoNostrSigner, + { + let relay_pool: Arc = + Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for server transport" + ); + error + })?); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + announced = config.is_announced_server, + encryption_mode = ?config.encryption_mode, + gift_wrap_mode = ?config.gift_wrap_mode, + "Created server transport" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + sessions: SessionStore::with_capacity(config.max_sessions), + config, + extra_common_tags: Vec::new(), + pricing_tags: Vec::new(), + event_routes: ServerEventRouteStore::new(), + request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), + seen_gift_wrap_ids, + message_tx: Some(tx), + message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + task_handles: Vec::new(), + }) + } + + /// Like [`new`](Self::new) but accepts an existing relay pool. + pub async fn with_relay_pool( + config: NostrServerTransportConfig, + relay_pool: Arc, + ) -> Result { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + announced = config.is_announced_server, + encryption_mode = ?config.encryption_mode, + "Created server transport (with_relay_pool)" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + sessions: SessionStore::with_capacity(config.max_sessions), + config, + extra_common_tags: Vec::new(), + pricing_tags: Vec::new(), + request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), + event_routes: ServerEventRouteStore::new(), + seen_gift_wrap_ids, + message_tx: Some(tx), + message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + task_handles: Vec::new(), + }) + } + + /// Start listening for incoming requests. + pub async fn start(&mut self) -> Result<()> { + self.base + .connect(&self.config.relay_urls) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to connect server transport to relays" + ); + error + })?; + + let pubkey = self.base.get_public_key().await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to fetch server transport public key" + ); + error + })?; + tracing::info!( + target: LOG_TARGET, + pubkey = %pubkey.to_hex(), + "Server transport started" + ); + + self.base + .subscribe_for_pubkey(&pubkey) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + pubkey = %pubkey.to_hex(), + "Failed to subscribe server transport for pubkey" + ); + error + })?; + + // Spawn event loop with cancellation support + let relay_pool = Arc::clone(&self.base.relay_pool); + let sessions = self.sessions.clone(); + let event_routes = self.event_routes.clone(); + let request_wrap_kinds = self.request_wrap_kinds.clone(); + let tx = self + .message_tx + .as_ref() + .expect("message_tx must exist before start()") + .clone(); + let allowed = self.config.allowed_public_keys.clone(); + let excluded = self.config.excluded_capabilities.clone(); + let encryption_mode = self.config.encryption_mode; + let gift_wrap_mode = self.config.gift_wrap_mode; + let is_announced_server = self.config.is_announced_server; + let server_info = self.config.server_info.clone(); + let extra_common_tags = self.extra_common_tags.clone(); + let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); + let event_loop_token = self.cancellation_token.child_token(); + + let event_loop_handle = tokio::spawn(async move { + Self::event_loop( + relay_pool, + sessions, + event_routes, + request_wrap_kinds, + tx, + allowed, + excluded, + encryption_mode, + gift_wrap_mode, + is_announced_server, + server_info, + extra_common_tags, + seen_gift_wrap_ids, + event_loop_token, + ) + .await; + }); + + // Spawn session cleanup with cancellation support + let sessions_cleanup = self.sessions.clone(); + let event_routes_cleanup = self.event_routes.clone(); + let request_wrap_kinds_cleanup = self.request_wrap_kinds.clone(); + let cleanup_interval = self.config.cleanup_interval; + let session_timeout = self.config.session_timeout; + let request_timeout = self.config.request_timeout; + let cleanup_token = self.cancellation_token.child_token(); + + let cleanup_handle = tokio::spawn(async move { + let mut interval = tokio::time::interval(cleanup_interval); + loop { + tokio::select! { + _ = cleanup_token.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Server cleanup task cancelled" + ); + break; + } + _ = interval.tick() => { + let cleaned = Self::cleanup_sessions( + &sessions_cleanup, + &event_routes_cleanup, + &request_wrap_kinds_cleanup, + session_timeout, + ) + .await; + if cleaned > 0 { + tracing::info!( + target: LOG_TARGET, + cleaned_sessions = cleaned, + "Cleaned up inactive sessions" + ); + } + } + } + + // Sweep stale route entries in active sessions (rmcp handles timeout errors). + let swept_event_ids = event_routes_cleanup + .sweep_stale_routes(request_timeout) + .await; + if !swept_event_ids.is_empty() { + let mut kinds_w = request_wrap_kinds_cleanup.write().await; + for event_id in &swept_event_ids { + kinds_w.remove(event_id); + } + drop(kinds_w); + tracing::warn!( + target: LOG_TARGET, + swept = swept_event_ids.len(), + timeout_secs = request_timeout.as_secs(), + "Swept stale event routes (rmcp handles timeout errors)" + ); + } + } + }); + + self.task_handles.push(event_loop_handle); + self.task_handles.push(cleanup_handle); + + tracing::info!( + target: LOG_TARGET, + relay_count = self.config.relay_urls.len(), + cleanup_interval_secs = self.config.cleanup_interval.as_secs(), + session_timeout_secs = self.config.session_timeout.as_secs(), + "Server transport loops spawned" + ); + Ok(()) + } + + /// Close the transport — cancels event loop and cleanup tasks, then disconnects. + pub async fn close(&mut self) -> Result<()> { + self.cancellation_token.cancel(); + for handle in self.task_handles.drain(..) { + let _ = handle.await; + } + self.message_tx.take(); + self.base.disconnect().await?; + self.sessions.clear().await; + self.event_routes.clear().await; + Ok(()) + } + + /// Send a response back to the client that sent the original request. + pub async fn send_response(&self, event_id: &str, mut response: JsonRpcMessage) -> Result<()> { + // Consume the route up-front so only one concurrent responder can proceed + // for a given event_id. + let route = self.event_routes.pop(event_id).await.ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + event_id = %event_id, + "No client found for response correlation" + ); + Error::Other(format!("No client found for event {event_id}")) + })?; + + let client_pubkey_hex = route.client_pubkey; + let original_request_id = route.original_request_id; + let progress_token = route.progress_token; + + let mut sessions_w = self.sessions.write().await; + let session = sessions_w.get_mut(&client_pubkey_hex).ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + client_pubkey = %client_pubkey_hex, + "No session for correlated client" + ); + Error::Other(format!("No session for client {client_pubkey_hex}")) + })?; + + // Restore original request ID + match &mut response { + JsonRpcMessage::Response(r) => r.id = original_request_id.clone(), + JsonRpcMessage::ErrorResponse(r) => r.id = original_request_id.clone(), + _ => {} + } + + let is_encrypted = session.is_encrypted; + + // CEP-35: include discovery tags on first response to this client + let discovery_tags = self.take_pending_server_discovery_tags(session); + drop(sessions_w); + + // CEP-19: Look up the incoming wrap kind for mirroring + let mirrored_wrap_kind = self + .request_wrap_kinds + .read() + .await + .get(event_id) + .copied() + .flatten(); + + let client_pubkey = PublicKey::from_hex(&client_pubkey_hex).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %client_pubkey_hex, + "Invalid client pubkey in session map" + ); + Error::Other(error.to_string()) + })?; + + let event_id_parsed = EventId::from_hex(event_id).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + event_id = %event_id, + "Invalid event id while sending response" + ); + Error::Other(error.to_string()) + })?; + + let base_tags = BaseTransport::create_response_tags(&client_pubkey, &event_id_parsed); + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); + + if let Err(error) = self + .base + .send_mcp_message( + &response, + &client_pubkey, + CTXVM_MESSAGES_KIND, + tags, + Some(is_encrypted), + Self::select_outbound_gift_wrap_kind( + self.config.gift_wrap_mode, + is_encrypted, + mirrored_wrap_kind, + ), + ) + .await + { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %client_pubkey_hex, + event_id = %event_id, + "Failed to publish response message" + ); + + // Re-register route on publish failure so caller can retry. + self.event_routes + .register( + event_id.to_string(), + client_pubkey_hex, + original_request_id, + progress_token, + ) + .await; + + return Err(error); + } + + // Clean up wrap-kind tracking + self.request_wrap_kinds.write().await.remove(event_id); + + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(&client_pubkey_hex) { + // Clean up progress token + if let Some(token) = progress_token { + session.pending_requests.remove(&token); + } + session.event_to_progress_token.remove(event_id); + session.pending_requests.remove(event_id); + } + drop(sessions); + + tracing::debug!( + target: LOG_TARGET, + client_pubkey = %client_pubkey_hex, + event_id = %event_id, + encrypted = is_encrypted, + "Sent server response and cleaned correlation state" + ); + Ok(()) + } + + /// Send a notification to a specific client. + pub async fn send_notification( + &self, + client_pubkey_hex: &str, + notification: &JsonRpcMessage, + correlated_event_id: Option<&str>, + ) -> Result<()> { + let mut sessions = self.sessions.write().await; + let session = sessions + .get_mut(client_pubkey_hex) + .ok_or_else(|| Error::Other(format!("No session for {client_pubkey_hex}")))?; + let is_encrypted = session.is_encrypted; + let supports_ephemeral = session.supports_ephemeral_gift_wrap; + + // CEP-35: include discovery tags on first message to this client + let discovery_tags = self.take_pending_server_discovery_tags(session); + drop(sessions); + + let client_pubkey = + PublicKey::from_hex(client_pubkey_hex).map_err(|e| Error::Other(e.to_string()))?; + + let mut base_tags = BaseTransport::create_recipient_tags(&client_pubkey); + if let Some(eid) = correlated_event_id { + let event_id = EventId::from_hex(eid).map_err(|e| Error::Other(e.to_string()))?; + base_tags.push(Tag::event(event_id)); + } + + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); + + // CEP-19: Look up mirrored wrap kind from correlated request + let correlated_wrap_kind = if let Some(event_id) = correlated_event_id { + self.request_wrap_kinds + .read() + .await + .get(event_id) + .copied() + .flatten() + } else { + None + }; + + self.base + .send_mcp_message( + notification, + &client_pubkey, + CTXVM_MESSAGES_KIND, + tags, + Some(is_encrypted), + Self::select_outbound_notification_gift_wrap_kind( + self.config.gift_wrap_mode, + is_encrypted, + correlated_wrap_kind, + supports_ephemeral, + ), + ) + .await?; + + Ok(()) + } + + /// Broadcast a notification to all initialized clients. + pub async fn broadcast_notification(&self, notification: &JsonRpcMessage) -> Result<()> { + let sessions = self.sessions.read().await; + let initialized: Vec = sessions + .iter() + .filter(|(_, s)| s.is_initialized) + .map(|(k, _)| k.clone()) + .collect(); + drop(sessions); + + for pubkey in initialized { + if let Err(error) = self.send_notification(&pubkey, notification, None).await { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %pubkey, + "Failed to send notification" + ); + } + } + Ok(()) + } + + /// Take the message receiver for consuming incoming requests. + pub fn take_message_receiver( + &mut self, + ) -> Option> { + self.message_rx.take() + } + + /// Sets extra discovery tags to include in announcements and first-response discovery replay. + pub fn set_announcement_extra_tags(&mut self, tags: Vec) { + self.extra_common_tags = tags; + } + + /// Sets pricing tags to include in announcement/list events and capability list responses. + pub fn set_announcement_pricing_tags(&mut self, tags: Vec) { + self.pricing_tags = tags; + } + + /// Publish server announcement (kind 11316). + pub async fn announce(&self) -> Result { + let info = self + .config + .server_info + .as_ref() + .ok_or_else(|| Error::Other("No server info configured".to_string()))?; + + let content = serde_json::to_string(info)?; + + let mut tags = Vec::new(); + if let Some(ref name) = info.name { + tags.push(Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec![name.clone()], + )); + } + if let Some(ref about) = info.about { + tags.push(Tag::custom( + TagKind::Custom(tags::ABOUT.into()), + vec![about.clone()], + )); + } + if let Some(ref website) = info.website { + tags.push(Tag::custom( + TagKind::Custom(tags::WEBSITE.into()), + vec![website.clone()], + )); + } + if let Some(ref picture) = info.picture { + tags.push(Tag::custom( + TagKind::Custom(tags::PICTURE.into()), + vec![picture.clone()], + )); + } + if self.config.encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if self.config.gift_wrap_mode.supports_ephemeral() { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + tags.extend(self.extra_common_tags.iter().cloned()); + tags.extend(self.pricing_tags.iter().cloned()); + + let builder = EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content).tags(tags); + + self.base.relay_pool.publish(builder).await + } + + /// Publish tools list (kind 11317). + pub async fn publish_tools(&self, tools: Vec) -> Result { + let content = serde_json::json!({ "tools": tools }); + let builder = EventBuilder::new( + Kind::Custom(TOOLS_LIST_KIND), + serde_json::to_string(&content)?, + ) + .tags(self.pricing_tags.iter().cloned()); + self.base.relay_pool.publish(builder).await + } + + /// Publish resources list (kind 11318). + pub async fn publish_resources(&self, resources: Vec) -> Result { + let content = serde_json::json!({ "resources": resources }); + let builder = EventBuilder::new( + Kind::Custom(RESOURCES_LIST_KIND), + serde_json::to_string(&content)?, + ) + .tags(self.pricing_tags.iter().cloned()); + self.base.relay_pool.publish(builder).await + } + + /// Publish prompts list (kind 11320). + pub async fn publish_prompts(&self, prompts: Vec) -> Result { + let content = serde_json::json!({ "prompts": prompts }); + let builder = EventBuilder::new( + Kind::Custom(PROMPTS_LIST_KIND), + serde_json::to_string(&content)?, + ) + .tags(self.pricing_tags.iter().cloned()); + self.base.relay_pool.publish(builder).await + } + + /// Publish resource templates list (kind 11319). + pub async fn publish_resource_templates( + &self, + templates: Vec, + ) -> Result { + let content = serde_json::json!({ "resourceTemplates": templates }); + let builder = EventBuilder::new( + Kind::Custom(RESOURCETEMPLATES_LIST_KIND), + serde_json::to_string(&content)?, + ) + .tags(self.pricing_tags.iter().cloned()); + self.base.relay_pool.publish(builder).await + } + + /// Delete server announcements (NIP-09 kind 5). + pub async fn delete_announcements(&self, reason: &str) -> Result<()> { + // We publish kind 5 events for each announcement kind + let pubkey = self.base.get_public_key().await?; + let _pubkey_hex = pubkey.to_hex(); + + for kind in UNENCRYPTED_KINDS { + let builder = EventBuilder::new(Kind::Custom(5), reason).tag(Tag::custom( + TagKind::Custom("k".into()), + vec![kind.to_string()], + )); + self.base.relay_pool.publish(builder).await?; + } + Ok(()) + } + + /// Publish tools list from rmcp typed tool descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_tools_typed(&self, tools: Vec) -> Result { + let tools = tools + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_tools(tools).await + } + + /// Publish resources list from rmcp typed resource descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_resources_typed( + &self, + resources: Vec, + ) -> Result { + let resources = resources + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_resources(resources).await + } + + /// Publish prompts list from rmcp typed prompt descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_prompts_typed( + &self, + prompts: Vec, + ) -> Result { + let prompts = prompts + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_prompts(prompts).await + } + + /// Publish resource templates list from rmcp typed template descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_resource_templates_typed( + &self, + templates: Vec, + ) -> Result { + let templates = templates + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_resource_templates(templates).await + } + + // ── CEP-35 discovery tag helpers ────────────────────────────── + + /// Build common discovery tags from server config. + /// + /// Includes server info tags (name, about, website, picture) and capability + /// tags (support_encryption, support_encryption_ephemeral) based on the + /// transport's encryption and gift-wrap mode. + fn get_common_tags(&self) -> Vec { + let mut tags = Vec::new(); + + // Server info tags + if let Some(ref info) = self.config.server_info { + if let Some(ref name) = info.name { + tags.push(Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec![name.clone()], + )); + } + if let Some(ref about) = info.about { + tags.push(Tag::custom( + TagKind::Custom(tags::ABOUT.into()), + vec![about.clone()], + )); + } + if let Some(ref website) = info.website { + tags.push(Tag::custom( + TagKind::Custom(tags::WEBSITE.into()), + vec![website.clone()], + )); + } + if let Some(ref picture) = info.picture { + tags.push(Tag::custom( + TagKind::Custom(tags::PICTURE.into()), + vec![picture.clone()], + )); + } + } + + // Capability tags + if self.config.encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if self.config.gift_wrap_mode.supports_ephemeral() { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + + tags + } + + /// One-shot: returns common tags if not yet sent to this client, empty otherwise. + fn take_pending_server_discovery_tags(&self, session: &mut ClientSession) -> Vec { + if session.has_sent_common_tags { + return vec![]; + } + session.has_sent_common_tags = true; + self.get_common_tags() + } + + // ── Internal ──────────────────────────────────────────────── + + fn is_capability_excluded( + excluded: &[CapabilityExclusion], + method: &str, + name: Option<&str>, + ) -> bool { + // Always allow fundamental MCP methods + if method == "initialize" || method == "notifications/initialized" { + return true; + } + + excluded.iter().any(|excl| { + if excl.method != method { + return false; + } + match (&excl.name, name) { + (Some(excl_name), Some(req_name)) => excl_name == req_name, + (None, _) => true, // method-only match + _ => false, + } + }) + } + + #[allow(clippy::too_many_arguments)] + async fn event_loop( + relay_pool: Arc, + sessions: SessionStore, + event_routes: ServerEventRouteStore, + request_wrap_kinds: Arc>>>, + tx: tokio::sync::mpsc::UnboundedSender, + allowed_pubkeys: Vec, + excluded_capabilities: Vec, + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + is_announced_server: bool, + server_info: Option, + extra_common_tags: Vec, + seen_gift_wrap_ids: Arc>>, + cancel: CancellationToken, + ) { + let mut notifications = relay_pool.notifications(); + + loop { + let notification = tokio::select! { + _ = cancel.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Server event loop cancelled" + ); + break; + } + result = notifications.recv() => { + match result { + Ok(n) => n, + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + target: LOG_TARGET, + skipped = n, + "Relay broadcast lagged, skipping missed events" + ); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }; + if let RelayPoolNotification::Event { event, .. } = notification { + let is_gift_wrap = event.kind == Kind::Custom(GIFT_WRAP_KIND) + || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND); + let outer_kind: u16 = event.kind.as_u16(); + + // CEP-19: Drop gift-wraps that violate the configured gift-wrap mode + if is_gift_wrap && !gift_wrap_mode.allows_kind(outer_kind) { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Dropping gift-wrap because it violates gift_wrap_mode policy" + ); + continue; + } + + let (content, sender_pubkey, event_id, is_encrypted, inner_tags) = if is_gift_wrap { + if encryption_mode == EncryptionMode::Disabled { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + sender_pubkey = %event.pubkey.to_hex(), + "Received encrypted message but encryption is disabled" + ); + continue; + } + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + continue; + } + } + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match relay_pool.signer().await { + Ok(s) => s, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); + continue; + } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { + Ok(decrypted_json) => { + // The decrypted content is JSON of the inner signed event. + // Use the INNER event's ID for correlation — the client + // registers the inner event ID in its correlation store. + match serde_json::from_str::(&decrypted_json) { + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!( + "Inner event signature verification failed: {e}" + ); + continue; + } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } + let inner_tags: Vec = inner.tags.to_vec(); + ( + inner.content, + inner.pubkey.to_hex(), + inner.id.to_hex(), + true, + inner_tags, + ) + } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" + ); + continue; + } + } + } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt" + ); + continue; + } + } + } else { + if encryption_mode == EncryptionMode::Required { + tracing::warn!( + target: LOG_TARGET, + sender_pubkey = %event.pubkey.to_hex(), + "Received unencrypted message but encryption is required" + ); + continue; + } + ( + event.content.clone(), + event.pubkey.to_hex(), + event.id.to_hex(), + false, + event.tags.to_vec(), + ) + }; + + // Parse MCP message + let mcp_msg = match validation::validate_and_parse(&content) { + Some(msg) => msg, + None => { + tracing::warn!( + target: LOG_TARGET, + sender_pubkey = %sender_pubkey, + "Invalid MCP message" + ); + continue; + } + }; + + // Authorization check + if !allowed_pubkeys.is_empty() { + let method = mcp_msg.method().unwrap_or(""); + let name = match &mcp_msg { + JsonRpcMessage::Request(r) => r + .params + .as_ref() + .and_then(|p| p.get("name")) + .and_then(|n| n.as_str()), + _ => None, + }; + + let is_excluded = + Self::is_capability_excluded(&excluded_capabilities, method, name); + + if !allowed_pubkeys.contains(&sender_pubkey) && !is_excluded { + tracing::warn!( + target: LOG_TARGET, + sender_pubkey = %sender_pubkey, + method = method, + "Unauthorized request" + ); + + // Send a JSON-RPC error back for Request messages so the + // client doesn't hang indefinitely (announced servers only). + if is_announced_server { + if let JsonRpcMessage::Request(ref req) = mcp_msg { + if let Ok(client_pk) = PublicKey::from_hex(&sender_pubkey) { + let event_id_parsed = EventId::from_hex(&event_id) + .unwrap_or(EventId::all_zeros()); + let mut tags = BaseTransport::create_response_tags( + &client_pk, + &event_id_parsed, + ); + + // CEP-19: Inject common discovery tags on first response + let has_sent = sessions + .get_session(&sender_pubkey) + .await + .is_some_and(|s| s.has_sent_common_tags); + if !has_sent { + Self::append_common_response_tags( + &mut tags, + server_info.as_ref(), + &extra_common_tags, + encryption_mode, + gift_wrap_mode, + ); + sessions.mark_common_tags_sent(&sender_pubkey).await; + } + + let error_response = + JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: req.id.clone(), + error: JsonRpcError { + code: -32000, + message: "Unauthorized".to_string(), + data: None, + }, + }); + + let base = BaseTransport { + relay_pool: Arc::clone(&relay_pool), + encryption_mode, + is_connected: true, + }; + if let Err(e) = base + .send_mcp_message( + &error_response, + &client_pk, + CTXVM_MESSAGES_KIND, + tags, + Some(is_encrypted), + Self::select_outbound_gift_wrap_kind( + gift_wrap_mode, + is_encrypted, + if is_gift_wrap { Some(outer_kind) } else { None }, + ), + ) + .await + { + tracing::error!( + target: LOG_TARGET, + error = %e, + sender_pubkey = %sender_pubkey, + "Failed to send unauthorized error response" + ); + } + } + } + } // if is_announced_server + + continue; + } + } + + // Session management + let on_evicted_cb = sessions.eviction_callback(); + let mut sessions_w = sessions.write().await; + if !sessions_w.contains(&sender_pubkey) { + let evicted = + sessions_w.push(sender_pubkey.clone(), ClientSession::new(is_encrypted)); + SessionStore::handle_eviction( + &sender_pubkey, + evicted, + &mut sessions_w, + on_evicted_cb.as_ref(), + &event_routes, + ) + .await; + } + let session = sessions_w.get_mut(&sender_pubkey).unwrap(); + session.update_activity(); + session.is_encrypted = is_encrypted; + + // CEP-19: Mark ephemeral support if client used kind 21059 + if is_gift_wrap && outer_kind == EPHEMERAL_GIFT_WRAP_KIND { + session.supports_ephemeral_gift_wrap = true; + } + + // CEP-35: learn client capabilities from inner event tags + let discovered = learn_peer_capabilities(&inner_tags); + session.supports_encryption |= discovered.supports_encryption; + session.supports_ephemeral_encryption |= discovered.supports_ephemeral_encryption; + // Only learn oversized support if CEP-22 is enabled on this server + // TODO: wire from config when CEP-22 lands + let oversized_enabled = false; + session.supports_oversized_transfer |= + oversized_enabled && discovered.supports_oversized_transfer; + + // Track request for correlation + if let JsonRpcMessage::Request(ref req) = mcp_msg { + let original_id = req.id.clone(); + + // Extract progress token from _meta if present. + let progress_token = req + .params + .as_ref() + .and_then(|p| p.get("_meta")) + .and_then(|m| m.get("progressToken")) + .and_then(|t| t.as_str()) + .map(String::from); + + // Duplicate into session fields (kept for backward compat). + session + .pending_requests + .insert(event_id.clone(), original_id.clone()); + if let Some(ref token) = progress_token { + session + .pending_requests + .insert(token.clone(), serde_json::json!(event_id)); + session + .event_to_progress_token + .insert(event_id.clone(), token.clone()); + } + + drop(sessions_w); + + // CEP-19: Record the incoming wrap kind for response mirroring + { + let mut kinds_w = request_wrap_kinds.write().await; + kinds_w.insert( + event_id.clone(), + if is_gift_wrap { Some(outer_kind) } else { None }, + ); + } + + event_routes + .register( + event_id.clone(), + sender_pubkey.clone(), + original_id, + progress_token, + ) + .await; + } else { + drop(sessions_w); + } + + // Handle initialized notification (re-acquire for write) + if let JsonRpcMessage::Notification(ref n) = mcp_msg { + if n.method == "notifications/initialized" { + let mut sessions_w2 = sessions.write().await; + if let Some(session) = sessions_w2.get_mut(&sender_pubkey) { + session.is_initialized = true; + } + } + } + + // Forward to consumer + let _ = tx.send(IncomingRequest { + message: mcp_msg, + client_pubkey: sender_pubkey, + event_id, + is_encrypted, + }); + } + } + } + + async fn cleanup_sessions( + sessions: &SessionStore, + event_routes: &ServerEventRouteStore, + request_wrap_kinds: &Arc>>>, + timeout: Duration, + ) -> usize { + let mut sessions_w = sessions.write().await; + let mut cleaned = 0; + let mut stale_event_ids = Vec::new(); + + // LruCache has no retain(); collect expired keys then pop each one. + let expired_keys: Vec = sessions_w + .iter() + .filter(|(_, session)| session.last_activity.elapsed() > timeout) + .map(|(k, _)| k.clone()) + .collect(); + + for key in &expired_keys { + if let Some(session) = sessions_w.pop(key) { + stale_event_ids.extend(session.pending_requests.keys().cloned()); + stale_event_ids.extend(session.event_to_progress_token.keys().cloned()); + tracing::debug!( + target: LOG_TARGET, + client_pubkey = %key, + "Session expired" + ); + cleaned += 1; + } + } + drop(sessions_w); + + { + let mut kinds_w = request_wrap_kinds.write().await; + for event_id in &stale_event_ids { + kinds_w.remove(event_id); + } + } + + for event_id in &stale_event_ids { + event_routes.pop(event_id).await; + } + + cleaned + } + + /// CEP-19: Choose outbound gift-wrap kind for responses. + /// If `is_encrypted` is false, return None (send plaintext). + /// Otherwise mirror the kind used by the client, falling back to the mode default. + fn select_outbound_gift_wrap_kind( + mode: GiftWrapMode, + is_encrypted: bool, + mirrored_kind: Option, + ) -> Option { + if !is_encrypted { + return None; + } + if let Some(kind) = mirrored_kind { + if mode.allows_kind(kind) { + return Some(kind); + } + } + match mode { + GiftWrapMode::Persistent => Some(GIFT_WRAP_KIND), + GiftWrapMode::Ephemeral => Some(EPHEMERAL_GIFT_WRAP_KIND), + GiftWrapMode::Optional => Some(GIFT_WRAP_KIND), + } + } + + /// CEP-19: Choose outbound gift-wrap kind for notifications. + fn select_outbound_notification_gift_wrap_kind( + mode: GiftWrapMode, + is_encrypted: bool, + correlated_wrap_kind: Option, + client_supports_ephemeral: bool, + ) -> Option { + if !is_encrypted { + return None; + } + // Mirror correlated request kind if available + if let Some(kind) = correlated_wrap_kind { + if mode.allows_kind(kind) { + return Some(kind); + } + } + // Fall back based on learned ephemeral support + if client_supports_ephemeral && mode.supports_ephemeral() { + return Some(EPHEMERAL_GIFT_WRAP_KIND); + } + match mode { + GiftWrapMode::Persistent => Some(GIFT_WRAP_KIND), + GiftWrapMode::Ephemeral => Some(EPHEMERAL_GIFT_WRAP_KIND), + GiftWrapMode::Optional => Some(GIFT_WRAP_KIND), + } + } + + /// CEP-19: Append server capability discovery tags to the given tag vec. + fn append_common_response_tags( + tags: &mut Vec, + server_info: Option<&ServerInfo>, + extra_common_tags: &[Tag], + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + ) { + if encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if gift_wrap_mode.supports_ephemeral() { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + if let Some(info) = server_info { + if let Some(ref name) = info.name { + tags.push(Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec![name.clone()], + )); + } + } + tags.extend(extra_common_tags.iter().cloned()); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread; + + // ── Session management ────────────────────────────────────── + + #[test] + fn test_client_session_creation() { + let session = ClientSession::new(true); + assert!(!session.is_initialized); + assert!(session.is_encrypted); + assert!(!session.has_sent_common_tags); + assert!(!session.supports_ephemeral_gift_wrap); + assert!(session.pending_requests.is_empty()); + assert!(session.event_to_progress_token.is_empty()); + } + + #[test] + fn test_client_session_update_activity() { + let mut session = ClientSession::new(false); + let first = session.last_activity; + thread::sleep(Duration::from_millis(10)); + session.update_activity(); + assert!(session.last_activity > first); + } + + #[tokio::test] + async fn test_cleanup_sessions_removes_expired() { + let sessions = SessionStore::new(); + let event_routes = ServerEventRouteStore::new(); + + // Insert a session with an old activity time + let mut session = ClientSession::new(false); + session + .pending_requests + .insert("evt1".to_string(), serde_json::json!(1)); + sessions.write().await.put("pubkey1".to_string(), session); + event_routes + .register( + "evt1".to_string(), + "pubkey1".to_string(), + serde_json::json!(1), + None, + ) + .await; + + let request_wrap_kinds = Arc::new(RwLock::new(HashMap::new())); + + // With a long timeout, nothing should be cleaned + let cleaned = NostrServerTransport::cleanup_sessions( + &sessions, + &event_routes, + &request_wrap_kinds, + Duration::from_secs(300), + ) + .await; + assert_eq!(cleaned, 0); + assert_eq!(sessions.session_count().await, 1); + + // With zero timeout, it should be cleaned + thread::sleep(Duration::from_millis(5)); + let cleaned = NostrServerTransport::cleanup_sessions( + &sessions, + &event_routes, + &request_wrap_kinds, + Duration::from_millis(1), + ) + .await; + assert_eq!(cleaned, 1); + assert_eq!(sessions.session_count().await, 0); + assert!(event_routes.pop("evt1").await.is_none()); + } + + #[tokio::test] + async fn test_cleanup_preserves_active_sessions() { + let sessions = SessionStore::new(); + let event_routes = ServerEventRouteStore::new(); + let request_wrap_kinds = Arc::new(RwLock::new(HashMap::new())); + + sessions + .get_or_create_session("active", false, &event_routes) + .await; + + let cleaned = NostrServerTransport::cleanup_sessions( + &sessions, + &event_routes, + &request_wrap_kinds, + Duration::from_secs(300), + ) + .await; + assert_eq!(cleaned, 0); + assert_eq!(sessions.session_count().await, 1); + } + + // ── Request ID correlation ────────────────────────────────── + + #[test] + fn test_pending_request_tracking() { + let mut session = ClientSession::new(false); + session + .pending_requests + .insert("event_abc".to_string(), serde_json::json!(42)); + assert_eq!( + session.pending_requests.get("event_abc"), + Some(&serde_json::json!(42)) + ); + } + + #[test] + fn test_progress_token_tracking() { + let mut session = ClientSession::new(false); + session + .event_to_progress_token + .insert("evt1".to_string(), "token1".to_string()); + session + .pending_requests + .insert("token1".to_string(), serde_json::json!("evt1")); + assert_eq!( + session.event_to_progress_token.get("evt1"), + Some(&"token1".to_string()) + ); + } + + // ── Authorization (is_capability_excluded) ────────────────── + + #[test] + fn test_initialize_always_excluded() { + assert!(NostrServerTransport::is_capability_excluded( + &[], + "initialize", + None + )); + assert!(NostrServerTransport::is_capability_excluded( + &[], + "notifications/initialized", + None + )); + } + + #[test] + fn test_method_excluded_without_name() { + let exclusions = vec![CapabilityExclusion { + method: "tools/list".to_string(), + name: None, + }]; + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/list", + None + )); + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/list", + Some("anything") + )); + } + + #[test] + fn test_method_excluded_with_name() { + let exclusions = vec![CapabilityExclusion { + method: "tools/call".to_string(), + name: Some("get_weather".to_string()), + }]; + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + Some("get_weather") + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + Some("other_tool") + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + None + )); + } + + #[test] + fn test_non_excluded_method() { + let exclusions = vec![CapabilityExclusion { + method: "tools/list".to_string(), + name: None, + }]; + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + None + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "resources/list", + None + )); + } + + #[test] + fn test_empty_exclusions_non_init_method() { + assert!(!NostrServerTransport::is_capability_excluded( + &[], + "tools/list", + None + )); + assert!(!NostrServerTransport::is_capability_excluded( + &[], + "tools/call", + Some("x") + )); + } + + // ── Encryption mode enforcement ───────────────────────────── + + #[test] + fn test_encryption_mode_default() { + let config = NostrServerTransportConfig::default(); + assert_eq!(config.encryption_mode, EncryptionMode::Optional); + } + + // ── Config defaults ───────────────────────────────────────── + + #[test] + fn test_config_defaults() { + let config = NostrServerTransportConfig::default(); + assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); + assert!(!config.is_announced_server); + assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); + assert!(config.allowed_public_keys.is_empty()); + assert!(config.excluded_capabilities.is_empty()); + assert_eq!(config.max_sessions, 1000); + assert_eq!(config.cleanup_interval, Duration::from_secs(60)); + assert_eq!(config.session_timeout, Duration::from_secs(300)); + assert_eq!(config.request_timeout, Duration::from_secs(60)); + assert!(config.server_info.is_none()); + } + + // ── CEP-19 helper logic ────────────────────────────────────── + + #[test] + fn test_select_outbound_gift_wrap_kind_plaintext() { + assert_eq!( + NostrServerTransport::select_outbound_gift_wrap_kind( + GiftWrapMode::Optional, + false, + Some(GIFT_WRAP_KIND), + ), + None + ); + } + + #[test] + fn test_select_outbound_gift_wrap_kind_mirrors_incoming() { + assert_eq!( + NostrServerTransport::select_outbound_gift_wrap_kind( + GiftWrapMode::Optional, + true, + Some(EPHEMERAL_GIFT_WRAP_KIND), + ), + Some(EPHEMERAL_GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_select_outbound_gift_wrap_kind_persistent_mode_overrides_ephemeral() { + assert_eq!( + NostrServerTransport::select_outbound_gift_wrap_kind( + GiftWrapMode::Persistent, + true, + Some(EPHEMERAL_GIFT_WRAP_KIND), + ), + Some(GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_append_common_response_tags_includes_encryption_when_optional() { + let mut tags = Vec::new(); + NostrServerTransport::append_common_response_tags( + &mut tags, + None, + &[], + EncryptionMode::Optional, + GiftWrapMode::Optional, + ); + let kinds: Vec = tags.iter().map(|t| format!("{:?}", t.kind())).collect(); + assert!( + kinds.iter().any(|k| k.contains("support_encryption")), + "should include support_encryption tag" + ); + } + + #[test] + fn test_append_common_response_tags_no_encryption_when_disabled() { + let mut tags = Vec::new(); + NostrServerTransport::append_common_response_tags( + &mut tags, + None, + &[], + EncryptionMode::Disabled, + GiftWrapMode::Optional, + ); + assert!( + tags.is_empty(), + "should not include encryption tags when encryption disabled" + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_plaintext() { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Optional, + false, + Some(EPHEMERAL_GIFT_WRAP_KIND), + true, + ), + None + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_mirrors_correlated() { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Optional, + true, + Some(EPHEMERAL_GIFT_WRAP_KIND), + false, + ), + Some(EPHEMERAL_GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_falls_back_to_mode_if_correlated_not_allowed( + ) { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Ephemeral, + true, + Some(GIFT_WRAP_KIND), + false, + ), + Some(EPHEMERAL_GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_uses_ephemeral_if_supported() { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Optional, + true, + None, + true, + ), + Some(EPHEMERAL_GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_uses_persistent_if_ephemeral_supported_but_mode_persistent( + ) { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Persistent, + true, + None, + true, + ), + Some(GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_select_outbound_notification_gift_wrap_kind_uses_default_mode_if_ephemeral_not_supported( + ) { + assert_eq!( + NostrServerTransport::select_outbound_notification_gift_wrap_kind( + GiftWrapMode::Optional, + true, + None, + false, + ), + Some(GIFT_WRAP_KIND) + ); + } + + #[test] + fn test_append_common_response_tags_includes_ephemeral_tag() { + let mut tags = Vec::new(); + NostrServerTransport::append_common_response_tags( + &mut tags, + None, + &[], + EncryptionMode::Optional, + GiftWrapMode::Optional, + ); + let kinds: Vec = tags.iter().map(|t| format!("{:?}", t.kind())).collect(); + assert!( + kinds + .iter() + .any(|k| k.contains("support_encryption_ephemeral")), + "should include support_encryption_ephemeral tag" + ); + } + + #[test] + fn test_append_common_response_tags_includes_server_info() { + let mut tags = Vec::new(); + let server_info = ServerInfo { + name: Some("TestServer".to_string()), + ..Default::default() + }; + NostrServerTransport::append_common_response_tags( + &mut tags, + Some(&server_info), + &[], + EncryptionMode::Disabled, + GiftWrapMode::Optional, + ); + let tag_value = tags + .iter() + .find(|t| (*t).clone().to_vec().first().map(|s| s.as_str()) == Some("name")) + .and_then(|t| t.clone().to_vec().get(1).cloned()); + assert_eq!(tag_value.as_deref(), Some("TestServer")); + } + + #[test] + fn test_append_common_response_tags_extra_tags() { + let mut tags = Vec::new(); + let extra_tags = vec![Tag::custom( + TagKind::Custom("custom_tag".into()), + vec!["value".to_string()], + )]; + NostrServerTransport::append_common_response_tags( + &mut tags, + None, + &extra_tags, + EncryptionMode::Disabled, + GiftWrapMode::Optional, + ); + let tag_value = tags + .iter() + .find(|t| (*t).clone().to_vec().first().map(|s| s.as_str()) == Some("custom_tag")) + .and_then(|t| t.clone().to_vec().get(1).cloned()); + assert_eq!(tag_value.as_deref(), Some("value")); + } + + // ── CEP-35 discovery tag helpers ──────────────────────────── + + #[test] + fn test_cep35_client_session_new_fields_default_false() { + let session = ClientSession::new(false); + assert!(!session.has_sent_common_tags); + assert!(!session.supports_encryption); + assert!(!session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[test] + fn test_cep35_capability_or_assign() { + let mut session = ClientSession::new(false); + + session.supports_encryption |= true; + session.supports_ephemeral_encryption |= false; + + session.supports_encryption |= false; + session.supports_ephemeral_encryption |= true; + + assert!(session.supports_encryption, "OR-assign must not downgrade"); + assert!(session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[test] + fn test_config_gift_wrap_mode_default() { + let config = NostrServerTransportConfig::default(); + assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); + } +} diff --git a/src/transport/server/session_store.rs b/src/transport/server/session_store.rs new file mode 100644 index 0000000..a37d53a --- /dev/null +++ b/src/transport/server/session_store.rs @@ -0,0 +1,534 @@ +//! Server-side session store for managing client sessions. +//! +//! Uses an LRU cache bounded by `max_sessions` (default 1000, matching the TS SDK +//! server session store). When a new session would exceed capacity the +//! least-recently-used session is evicted. If the evicted session still has +//! active routes in the correlation store it is recreated with clean state +//! (eviction safety, matching TS SDK's `hasActiveRoutesForClient` check), and +//! the optional eviction callback fires so external code can clean up resources. + +use std::num::NonZeroUsize; +use std::sync::Arc; + +use lru::LruCache; +use tokio::sync::RwLock; + +use crate::core::types::ClientSession; +use crate::transport::server::ServerEventRouteStore; + +const LOG_TARGET: &str = "contextvm_sdk::transport::server::session_store"; + +/// Default maximum number of concurrent client sessions. +/// +/// Matches the TS SDK's `SessionStore` default (`maxSessions ?? 1000`), not +/// the broader `DEFAULT_LRU_SIZE` constant (5000) used elsewhere in the TS SDK. +pub const DEFAULT_MAX_SESSIONS: usize = 1000; + +/// Callback invoked when a session is evicted from the LRU cache. +/// Receives the evicted client's public key (hex). +pub type EvictionCallback = Arc; + +/// Manages client sessions keyed by public key (hex). +/// +/// Backed by an LRU cache so memory usage is bounded. +#[derive(Clone)] +pub struct SessionStore { + sessions: Arc>>, + on_evicted: Option, +} + +impl Default for SessionStore { + fn default() -> Self { + Self::new() + } +} + +impl SessionStore { + /// Create a store with the default capacity ([`DEFAULT_MAX_SESSIONS`]). + pub fn new() -> Self { + Self::with_capacity(DEFAULT_MAX_SESSIONS) + } + + /// Create a store with a specific maximum number of sessions. + pub fn with_capacity(max_sessions: usize) -> Self { + Self { + sessions: Arc::new(RwLock::new(LruCache::new( + NonZeroUsize::new(max_sessions).unwrap_or(NonZeroUsize::new(1).unwrap()), + ))), + on_evicted: None, + } + } + + /// Register a callback that fires when a session is evicted from the LRU. + pub fn set_eviction_callback(&mut self, cb: EvictionCallback) { + self.on_evicted = Some(cb); + } + + /// Clone the eviction callback (cheap Arc clone) for use outside the lock. + pub fn eviction_callback(&self) -> Option { + self.on_evicted.clone() + } + + /// Get an existing session or create a new one. Returns `true` if a new session was created. + /// + /// `event_routes` is consulted during eviction safety: if the evicted client + /// still has active routes, the session is recreated with clean state + /// (matching TS SDK's `hasActiveRoutesForClient` check). + pub async fn get_or_create_session( + &self, + client_pubkey: &str, + is_encrypted: bool, + event_routes: &ServerEventRouteStore, + ) -> bool { + let on_evicted = self.on_evicted.clone(); + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(client_pubkey) { + session.is_encrypted = is_encrypted; + false + } else { + let new_session = ClientSession::new(is_encrypted); + let evicted = sessions.push(client_pubkey.to_string(), new_session); + Self::handle_eviction( + client_pubkey, + evicted, + &mut sessions, + on_evicted.as_ref(), + event_routes, + ) + .await; + true + } + } + + /// Get a read-only snapshot of session fields. + /// Returns `None` if the session does not exist. + pub async fn get_session(&self, client_pubkey: &str) -> Option { + let sessions = self.sessions.read().await; + sessions.peek(client_pubkey).map(|s| SessionSnapshot { + is_initialized: s.is_initialized, + is_encrypted: s.is_encrypted, + has_sent_common_tags: s.has_sent_common_tags, + supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap, + }) + } + + /// Mark a session as initialized. Returns `true` if the session existed. + pub async fn mark_initialized(&self, client_pubkey: &str) -> bool { + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(client_pubkey) { + session.is_initialized = true; + true + } else { + false + } + } + + /// Mark that common tags have been sent for this session. + pub async fn mark_common_tags_sent(&self, client_pubkey: &str) -> bool { + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(client_pubkey) { + session.has_sent_common_tags = true; + true + } else { + false + } + } + + /// Remove a session. Returns `true` if it existed. + pub async fn remove_session(&self, client_pubkey: &str) -> bool { + self.sessions.write().await.pop(client_pubkey).is_some() + } + + /// Remove all sessions. + pub async fn clear(&self) { + self.sessions.write().await.clear(); + } + + /// Number of active sessions. + pub async fn session_count(&self) -> usize { + self.sessions.read().await.len() + } + + /// Return a snapshot of all sessions as `(client_pubkey, snapshot)` pairs. + pub async fn get_all_sessions(&self) -> Vec<(String, SessionSnapshot)> { + let sessions = self.sessions.read().await; + sessions + .iter() + .map(|(k, s)| { + ( + k.clone(), + SessionSnapshot { + is_initialized: s.is_initialized, + is_encrypted: s.is_encrypted, + has_sent_common_tags: s.has_sent_common_tags, + supports_ephemeral_gift_wrap: s.supports_ephemeral_gift_wrap, + }, + ) + }) + .collect() + } + + /// Acquire write access to the underlying LRU cache (transport internals only). + pub(crate) async fn write( + &self, + ) -> tokio::sync::RwLockWriteGuard<'_, LruCache> { + self.sessions.write().await + } + + /// Acquire read access to the underlying LRU cache (transport internals only). + pub(crate) async fn read( + &self, + ) -> tokio::sync::RwLockReadGuard<'_, LruCache> { + self.sessions.read().await + } + + /// Handle a potential LRU eviction after inserting a session. + /// + /// If the evicted client still has active routes in the correlation store, + /// a clean session is re-inserted (eviction safety, matching TS SDK's + /// `hasActiveRoutesForClient` check). The eviction callback fires only + /// for genuine, non-vetoed evictions. + pub(crate) async fn handle_eviction( + inserted_key: &str, + evicted: Option<(String, ClientSession)>, + sessions: &mut LruCache, + on_evicted: Option<&EvictionCallback>, + event_routes: &ServerEventRouteStore, + ) { + if let Some((evicted_key, evicted_session)) = evicted { + // `push` also returns the old value when the *same* key is updated; + // only act when a *different* key was evicted due to capacity. + if evicted_key != inserted_key { + if event_routes + .has_active_routes_for_client(&evicted_key) + .await + { + tracing::warn!( + target: LOG_TARGET, + client_pubkey = %evicted_key, + "LRU eviction of session with active routes; recreating with clean state" + ); + // Re-insert with clean state so the client isn't orphaned. + // Skip the external callback — the session still exists + // (matches TS SDK: vetoed evictions don't fire the callback). + let _ = sessions.push( + evicted_key.clone(), + ClientSession::new(evicted_session.is_encrypted), + ); + } else if let Some(cb) = on_evicted { + cb(evicted_key); + } + } + } + } +} + +/// A lightweight snapshot of session state (avoids exposing the full `ClientSession` +/// through the async API boundary). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSnapshot { + pub is_initialized: bool, + pub is_encrypted: bool, + pub has_sent_common_tags: bool, + pub supports_ephemeral_gift_wrap: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn routes() -> ServerEventRouteStore { + ServerEventRouteStore::new() + } + + #[tokio::test] + async fn create_and_retrieve_session() { + let store = SessionStore::new(); + let r = routes(); + + let created = store.get_or_create_session("client-1", true, &r).await; + assert!(created); + + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_encrypted); + assert!(!snap.is_initialized); + } + + #[tokio::test] + async fn get_or_create_returns_existing() { + let store = SessionStore::new(); + let r = routes(); + + let created = store.get_or_create_session("client-1", false, &r).await; + assert!(created); + + let created2 = store.get_or_create_session("client-1", true, &r).await; + assert!(!created2); + + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_encrypted); + } + + #[tokio::test] + async fn mark_initialized() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + assert!(store.mark_initialized("client-1").await); + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_initialized); + } + + #[tokio::test] + async fn mark_initialized_unknown_returns_false() { + let store = SessionStore::new(); + assert!(!store.mark_initialized("unknown").await); + } + + #[tokio::test] + async fn remove_session() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + assert!(store.remove_session("client-1").await); + assert!(store.get_session("client-1").await.is_none()); + } + + #[tokio::test] + async fn remove_unknown_returns_false() { + let store = SessionStore::new(); + assert!(!store.remove_session("unknown").await); + } + + #[tokio::test] + async fn clear_all_sessions() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; + + store.clear().await; + + assert_eq!(store.session_count().await, 0); + assert!(store.get_session("client-1").await.is_none()); + assert!(store.get_session("client-2").await.is_none()); + } + + #[tokio::test] + async fn get_all_sessions() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; + + let all = store.get_all_sessions().await; + assert_eq!(all.len(), 2); + + let keys: Vec<&str> = all.iter().map(|(k, _)| k.as_str()).collect(); + assert!(keys.contains(&"client-1")); + assert!(keys.contains(&"client-2")); + } + + // ── CEP-35 capability fields ──────────────────────────────── + + #[tokio::test] + async fn new_session_capability_fields_default_false() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + let sessions = store.read().await; + let session = sessions.peek("client-1").unwrap(); + assert!(!session.has_sent_common_tags); + assert!(!session.supports_encryption); + assert!(!session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[tokio::test] + async fn has_sent_common_tags_flag() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + assert!(!session.has_sent_common_tags); + session.has_sent_common_tags = true; + assert!(session.has_sent_common_tags); + } + + #[tokio::test] + async fn capability_or_assign_persists() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption |= true; + session.supports_ephemeral_encryption |= false; + } + + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption |= false; + session.supports_ephemeral_encryption |= true; + } + + let sessions = store.read().await; + let session = sessions.peek("client-1").unwrap(); + assert!(session.supports_encryption, "OR-assign must not downgrade"); + assert!(session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[tokio::test] + async fn capability_fields_independent_per_client() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-a", false, &r).await; + store.get_or_create_session("client-b", false, &r).await; + + { + let mut sessions = store.write().await; + let sa = sessions.get_mut("client-a").unwrap(); + sa.supports_encryption = true; + sa.has_sent_common_tags = true; + } + + let sessions = store.read().await; + let sa = sessions.peek("client-a").unwrap(); + let sb = sessions.peek("client-b").unwrap(); + assert!(sa.supports_encryption); + assert!(sa.has_sent_common_tags); + assert!(!sb.supports_encryption); + assert!(!sb.has_sent_common_tags); + } + + #[tokio::test] + async fn get_or_create_preserves_capability_fields() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption = true; + session.has_sent_common_tags = true; + } + + let created = store.get_or_create_session("client-1", true, &r).await; + assert!(!created); + + let sessions = store.read().await; + let session = sessions.peek("client-1").unwrap(); + assert!(session.supports_encryption); + assert!(session.has_sent_common_tags); + } + + #[tokio::test] + async fn clear_resets_capability_fields() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + { + let mut sessions = store.write().await; + let s = sessions.get_mut("client-1").unwrap(); + s.supports_encryption = true; + } + + store.clear().await; + store.get_or_create_session("client-1", false, &r).await; + + let sessions = store.read().await; + let session = sessions.peek("client-1").unwrap(); + assert!(!session.supports_encryption); + assert!(!session.has_sent_common_tags); + } + + // ── LRU eviction ──────────────────────────────────────────── + + #[tokio::test] + async fn lru_eviction_drops_oldest_session() { + let store = SessionStore::with_capacity(3); + let r = routes(); + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + store.get_or_create_session("d", false, &r).await; + + assert!( + store.get_session("a").await.is_none(), + "a should be evicted" + ); + assert!(store.get_session("b").await.is_some()); + assert!(store.get_session("c").await.is_some()); + assert!(store.get_session("d").await.is_some()); + assert_eq!(store.session_count().await, 3); + } + + #[tokio::test] + async fn eviction_callback_fires_on_lru_eviction() { + let evicted = Arc::new(std::sync::Mutex::new(Vec::::new())); + let evicted_clone = evicted.clone(); + let r = routes(); + + let mut store = SessionStore::with_capacity(2); + store.set_eviction_callback(Arc::new(move |pubkey| { + evicted_clone.lock().unwrap().push(pubkey); + })); + + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + let evicted = evicted.lock().unwrap(); + assert_eq!(evicted.len(), 1); + assert_eq!(evicted[0], "a"); + } + + #[tokio::test] + async fn eviction_safety_recreates_session_with_active_routes() { + let store = SessionStore::with_capacity(2); + let r = routes(); + store.get_or_create_session("a", true, &r).await; + store.get_or_create_session("b", false, &r).await; + + // Register an active route for client "a" in the correlation store + r.register("evt1".into(), "a".into(), json!(1), None).await; + + // Adding "c" would normally evict "a", but eviction safety recreates it + // because "a" has active routes. + store.get_or_create_session("c", false, &r).await; + + let snap = store.get_session("a").await; + assert!( + snap.is_some(), + "session with active routes must survive eviction" + ); + // "b" was evicted instead (next LRU after "a" was re-inserted) + assert!( + store.get_session("b").await.is_none(), + "b should be evicted" + ); + } + + #[tokio::test] + async fn with_capacity_sets_limit() { + let store = SessionStore::with_capacity(5); + let r = routes(); + for i in 0..10 { + store + .get_or_create_session(&format!("client-{i}"), false, &r) + .await; + } + assert_eq!(store.session_count().await, 5); + } +} diff --git a/tests/conformance_dedup.rs b/tests/conformance_dedup.rs new file mode 100644 index 0000000..3f53050 --- /dev/null +++ b/tests/conformance_dedup.rs @@ -0,0 +1,113 @@ +//! Conformance tests for gift-wrap deduplication via LRU cache. +//! +//! Both the client and server transports use an `LruCache` to skip +//! duplicate outer gift-wrap event IDs. The dedup check happens *before* decrypt +//! and the insert happens only *after* successful decrypt + inner `verify()`. +//! These tests exercise the LRU cache logic in isolation — no async, no transport. + +use std::num::NonZeroUsize; + +use lru::LruCache; +use nostr_sdk::prelude::*; + +use contextvm_sdk::core::constants::DEFAULT_LRU_SIZE; + +/// Helper: build a cache with the same capacity used by both transports. +fn new_dedup_cache() -> LruCache { + LruCache::new(NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero")) +} + +fn event_id_from_byte(b: u8) -> EventId { + EventId::from_byte_array([b; 32]) +} + +// ── Gift-wrap kind 1059 dedup ───────────────────────────────────────────────── + +#[test] +fn client_dedup_skips_duplicate_outer_gift_wrap_id() { + let mut cache = new_dedup_cache(); + let outer_id = event_id_from_byte(0x01); + + // First delivery: not yet seen, decrypt succeeds, insert into cache. + assert!( + !cache.contains(&outer_id), + "first delivery must not be in cache yet" + ); + cache.put(outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&outer_id), + "second delivery of the same outer id must be rejected" + ); +} + +#[test] +fn client_dedup_ephemeral_gift_wrap_skips_duplicate() { + let mut cache = new_dedup_cache(); + let ephemeral_outer_id = event_id_from_byte(0xE1); + + // First delivery of an ephemeral gift-wrap (kind 21059). + assert!( + !cache.contains(&ephemeral_outer_id), + "first delivery must not be in cache yet" + ); + cache.put(ephemeral_outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&ephemeral_outer_id), + "second delivery of the same ephemeral outer id must be rejected" + ); +} + +// ── Server dedup ────────────────────────────────────────────────────────────── + +#[test] +fn server_dedup_ephemeral_gift_wrap_skips_duplicate() { + let mut cache = new_dedup_cache(); + let ephemeral_outer_id = event_id_from_byte(0xE2); + + // First delivery of an ephemeral gift-wrap (kind 21059). + assert!( + !cache.contains(&ephemeral_outer_id), + "first delivery must not be in cache yet" + ); + cache.put(ephemeral_outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&ephemeral_outer_id), + "second delivery of the same ephemeral outer id must be rejected" + ); +} + +#[test] +fn server_dedup_lru_evicts_oldest_when_capacity_reached() { + let capacity = 3; + let mut cache: LruCache = + LruCache::new(NonZeroUsize::new(capacity).expect("non-zero")); + + let id_0 = event_id_from_byte(0x00); + let id_1 = event_id_from_byte(0x01); + let id_2 = event_id_from_byte(0x02); + let id_3 = event_id_from_byte(0x03); + + cache.put(id_0, ()); + cache.put(id_1, ()); + cache.put(id_2, ()); + + // Cache is at capacity (3). Inserting a fourth must evict the oldest (id_0). + cache.put(id_3, ()); + + assert!( + !cache.contains(&id_0), + "oldest entry must be evicted when capacity is exceeded" + ); + assert!(cache.contains(&id_1), "second entry must still be present"); + assert!(cache.contains(&id_2), "third entry must still be present"); + assert!( + cache.contains(&id_3), + "newly inserted entry must be present" + ); +} diff --git a/tests/conformance_signer.rs b/tests/conformance_signer.rs new file mode 100644 index 0000000..5b7c661 --- /dev/null +++ b/tests/conformance_signer.rs @@ -0,0 +1,75 @@ +//! Conformance tests for signer behavior (hex `from_sk`, `generate`, NIP-44, signing). +//! +//! Same layout as `conformance_wire_format.rs`; scenarios follow the TS SDK +//! `private-key-signer.test.ts` alongside `src/signer/mod.rs` / `src/encryption/mod.rs`. + +use contextvm_sdk::encryption::{decrypt_nip44, encrypt_nip44}; +use contextvm_sdk::signer::{self, Keys}; +use nostr_sdk::prelude::*; + +/// Secret `1`, x-only pubkey of secp256k1 `G`. +const FIXTURE_SK_HEX: &str = "0000000000000000000000000000000000000000000000000000000000000001"; +const FIXTURE_PK_HEX: &str = "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"; + +fn fixture_keys() -> Keys { + signer::from_sk(FIXTURE_SK_HEX).expect("fixture SK hex parses") +} + +// ── Key derivation ─────────────────────────────────────────────────────────── + +#[test] +fn signer_generates_keypair_from_secret_key() { + let keys = fixture_keys(); + assert_eq!(keys.public_key().to_hex(), FIXTURE_PK_HEX); +} + +// ── Random generation ──────────────────────────────────────────────────────── + +#[test] +fn signer_generates_random_keypair_when_no_secret_provided() { + let keys = signer::generate(); + assert_eq!(keys.public_key().to_hex().len(), 64); +} + +// ── NIP-44 ─────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn signer_nip44_encrypt_decrypt_roundtrip() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + let plaintext = "Hello Encryption!"; + + let ciphertext = encrypt_nip44(&sender_keys, &recipient_keys.public_key(), plaintext) + .await + .expect("nip44 encrypt"); + + assert_ne!(ciphertext, plaintext); + + let decrypted = decrypt_nip44(&recipient_keys, &sender_keys.public_key(), &ciphertext) + .await + .expect("nip44 decrypt"); + + assert_eq!(decrypted, plaintext); +} + +// ── Public key ─────────────────────────────────────────────────────────────── + +#[test] +fn signer_get_public_key_returns_correct_key() { + let keys = fixture_keys(); + let expected_pk = PublicKey::parse(FIXTURE_PK_HEX).expect("fixture PK hex parses"); + assert_eq!(keys.public_key(), expected_pk); +} + +// ── Signed events ──────────────────────────────────────────────────────────── + +#[test] +fn signer_signed_event_has_valid_signature() { + let keys = fixture_keys(); + let event = EventBuilder::new(Kind::TextNote, "Hello Nostr!") + .sign_with_keys(&keys) + .expect("sign text note"); + + assert_eq!(event.pubkey, keys.public_key()); + event.verify().expect("verify signed event"); +} diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs new file mode 100644 index 0000000..0e30a65 --- /dev/null +++ b/tests/conformance_stateless_mode.rs @@ -0,0 +1,194 @@ +//! Stateless-mode conformance tests for the client transport. + +use std::time::Duration; + +use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, +}; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::{signer, GiftWrapMode}; +use tokio::time::timeout; + +async fn make_stateless_transport() -> ( + NostrClientTransport, + tokio::sync::mpsc::UnboundedReceiver, +) { + let server_keys = signer::generate(); + let client_keys = signer::generate(); + + let config = NostrClientTransportConfig::default() + .with_relay_urls(Vec::new()) + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_stateless(true) + .with_timeout(Duration::from_secs(1)); + + let mut transport = NostrClientTransport::new(client_keys, config) + .await + .expect("transport should be constructed"); + let rx = transport + .take_message_receiver() + .expect("message receiver should be available once"); + + (transport, rx) +} + +#[tokio::test] +async fn create_emulated_response_returns_correct_request_id() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("test-id"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + transport + .send(&request) + .await + .expect("initialize should be emulated in stateless mode"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("should receive emulated response promptly") + .expect("channel should contain response"); + + match msg { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!("test-id")); + assert_eq!(resp.jsonrpc, "2.0"); + assert_eq!( + resp.result + .get("protocolVersion") + .and_then(serde_json::Value::as_str), + Some(mcp_protocol_version()) + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("name")) + .and_then(serde_json::Value::as_str), + Some("Emulated-Stateless-Server") + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("version")) + .and_then(serde_json::Value::as_str), + Some("1.0.0") + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("tools")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("prompts")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("subscribe")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + } + other => panic!("expected Response, got {other:?}"), + } + + let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; + assert!( + duplicate.is_err(), + "initialize request should emit exactly one emulated response" + ); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_true_for_initialize() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: None, + }); + + transport + .send(&request) + .await + .expect("initialize should be handled statelessly"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("initialize should produce local emulated response") + .expect("response should be delivered"); + + assert_eq!(msg.id(), Some(&serde_json::json!(1))); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_false_for_other_methods() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + + let _send_result = transport.send(&request).await; + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "non-initialize request should not create a local emulated response" + ); +} + +#[tokio::test] +async fn notifications_initialized_swallowed_in_stateless_mode() { + let (transport, mut rx) = make_stateless_transport().await; + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + + transport + .send(¬ification) + .await + .expect("notifications/initialized should be accepted in stateless mode"); + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "notifications/initialized must be swallowed in stateless mode" + ); +} diff --git a/tests/conformance_stores.rs b/tests/conformance_stores.rs new file mode 100644 index 0000000..8216403 --- /dev/null +++ b/tests/conformance_stores.rs @@ -0,0 +1,828 @@ +//! Conformance tests for store abstractions. +//! +//! Ported from the TS SDK: +//! - `src/transport/nostr-client/correlation-store.test.ts` +//! - `src/transport/nostr-server/session-store.test.ts` +//! - `src/transport/nostr-server/correlation-store.test.ts` + +use contextvm_sdk::{ClientCorrelationStore, ServerEventRouteStore, SessionStore}; +use serde_json::json; + +// ════════════════════════════════════════════════════════════════════ +// Client Correlation Store +// ════════════════════════════════════════════════════════════════════ + +mod client_correlation_store { + use super::*; + + // ── registerRequest ─────────────────────────────────────────── + + #[tokio::test] + async fn stores_request_with_event_id() { + let store = ClientCorrelationStore::new(); + store + .register("event123".into(), json!("req1"), false) + .await; + assert!(store.contains("event123").await); + } + + #[tokio::test] + async fn stores_and_resolves_original_request_id() { + let store = ClientCorrelationStore::new(); + store + .register("event456".into(), json!("req2"), false) + .await; + + // Retrieve the stored original ID. + let original = store.get_original_id("event456").await.unwrap(); + assert_eq!(original, json!("req2")); + + // After removal the entry is fully gone. + assert!(store.remove("event456").await); + assert!(store.get_original_id("event456").await.is_none()); + } + + #[tokio::test] + async fn register_request_flags_initialize_requests() { + let store = ClientCorrelationStore::new(); + store.register("e_init".into(), json!("r1"), true).await; + store.register("e_normal".into(), json!("r2"), false).await; + + assert!(store.is_initialize_request("e_init").await); + assert!(!store.is_initialize_request("e_normal").await); + assert!(!store.is_initialize_request("unknown").await); + } + + // ── resolveResponse (get_original_id + remove) ──────────────── + + #[tokio::test] + async fn restores_original_request_id() { + let store = ClientCorrelationStore::new(); + store.register("event789".into(), json!(42), false).await; + let original = store.get_original_id("event789").await.unwrap(); + assert_eq!(original, json!(42)); + } + + #[tokio::test] + async fn returns_none_for_unknown_event_id() { + let store = ClientCorrelationStore::new(); + assert!(store.get_original_id("unknown").await.is_none()); + } + + #[tokio::test] + async fn get_and_remove_roundtrip() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!("req1"), false).await; + + // Lookup succeeds before removal. + let original = store.get_original_id("event1").await.unwrap(); + assert_eq!(original, json!("req1")); + + // Remove returns true and cleans up completely. + assert!(store.remove("event1").await); + assert!(!store.contains("event1").await); + assert!(store.get_original_id("event1").await.is_none()); + } + + // ── removePendingRequest ────────────────────────────────────── + + #[tokio::test] + async fn removes_existing_request() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!(null), false).await; + assert!(store.remove("event1").await); + assert!(!store.contains("event1").await); + } + + #[tokio::test] + async fn returns_false_for_unknown_request() { + let store = ClientCorrelationStore::new(); + assert!(!store.remove("unknown").await); + } + + // ── clear ───────────────────────────────────────────────────── + + #[tokio::test] + async fn removes_all_pending_requests() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!(null), false).await; + store.register("event2".into(), json!(null), false).await; + store.clear().await; + assert_eq!(store.count().await, 0); + } + + // ── LRU eviction (TS SDK client test 9) ─────────────────────── + + #[tokio::test] + async fn evicts_oldest_when_capacity_reached() { + let store = ClientCorrelationStore::with_max_pending(2); + for i in 0..5 { + store + .register(format!("event{i}"), json!(null), false) + .await; + } + assert_eq!(store.count().await, 2); + // Only the two most recent entries survive. + assert!(!store.contains("event0").await); + assert!(!store.contains("event1").await); + assert!(!store.contains("event2").await); + assert!(store.contains("event3").await); + assert!(store.contains("event4").await); + } +} + +// ════════════════════════════════════════════════════════════════════ +// Server Session Store +// ════════════════════════════════════════════════════════════════════ + +mod server_session_store { + use super::*; + + fn routes() -> ServerEventRouteStore { + ServerEventRouteStore::new() + } + + #[tokio::test] + async fn create_and_retrieve_sessions() { + let store = SessionStore::new(); + let r = routes(); + + let created = store.get_or_create_session("client-1", true, &r).await; + assert!(created); + + let session = store.get_session("client-1").await.unwrap(); + assert!(session.is_encrypted); + assert!(!session.is_initialized); + + // Retrieving same key should return it + assert!(store.get_session("client-1").await.is_some()); + } + + #[tokio::test] + async fn mark_sessions_as_initialized() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + let result = store.mark_initialized("client-1").await; + assert!(result); + + let session = store.get_session("client-1").await.unwrap(); + assert!(session.is_initialized); + } + + #[tokio::test] + async fn remove_sessions() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + + let result = store.remove_session("client-1").await; + assert!(result); + assert!(store.get_session("client-1").await.is_none()); + } + + #[tokio::test] + async fn clear_all_sessions() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; + + store.clear().await; + + assert_eq!(store.session_count().await, 0); + assert!(store.get_session("client-1").await.is_none()); + assert!(store.get_session("client-2").await.is_none()); + } + + #[tokio::test] + async fn iterate_over_all_sessions() { + let store = SessionStore::new(); + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; + + let sessions = store.get_all_sessions().await; + assert_eq!(sessions.len(), 2); + + let keys: Vec<&str> = sessions.iter().map(|(k, _)| k.as_str()).collect(); + assert!(keys.contains(&"client-1")); + assert!(keys.contains(&"client-2")); + } +} + +// ════════════════════════════════════════════════════════════════════ +// Server Correlation Store (ServerEventRouteStore) +// ════════════════════════════════════════════════════════════════════ + +mod server_correlation_store { + use super::*; + + // ── registerEventRoute ──────────────────────────────────────── + + #[tokio::test] + async fn registers_route_with_all_fields() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert_eq!(route.client_pubkey, "client1"); + assert_eq!(route.original_request_id, json!("req1")); + assert_eq!(route.progress_token.as_deref(), Some("token1")); + } + + #[tokio::test] + async fn registers_route_without_progress_token() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert!(route.progress_token.is_none()); + } + + #[tokio::test] + async fn registers_route_with_numeric_request_id() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!(42), None) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert_eq!(route.original_request_id, json!(42)); + } + + #[tokio::test] + async fn updates_client_index_when_registering() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + + assert!(store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn registers_progress_token_mapping() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + assert!(store.has_progress_token("token1").await); + } + + // ── getEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn returns_none_for_unknown_event_id() { + let store = ServerEventRouteStore::new(); + assert!(store.get_route("unknown").await.is_none()); + } + + // ── popEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn returns_and_removes_route_atomically() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + let route = store.pop("event1").await.unwrap(); + assert_eq!(route.client_pubkey, "client1"); + assert_eq!(route.original_request_id, json!("req1")); + assert_eq!(route.progress_token.as_deref(), Some("token1")); + + // Route + token mapping should be gone. + assert!(!store.has_event_route("event1").await); + assert!(!store.has_progress_token("token1").await); + + // Second pop is a no-op. + assert!(store.pop("event1").await.is_none()); + } + + // ── getEventIdByProgressToken ───────────────────────────────── + + #[tokio::test] + async fn returns_none_for_unknown_token() { + let store = ServerEventRouteStore::new(); + assert!(store + .get_event_id_by_progress_token("unknown") + .await + .is_none()); + } + + #[tokio::test] + async fn returns_correct_event_id_for_token() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client2".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + assert_eq!( + store + .get_event_id_by_progress_token("token2") + .await + .as_deref(), + Some("event2") + ); + } + + // ── removeRoutesForClient ───────────────────────────────────── + + #[tokio::test] + async fn removes_all_routes_for_client() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store + .register("event3".into(), "client2".into(), json!("req3"), None) + .await; + + let removed = store.remove_for_client("client1").await; + assert_eq!(removed, 2); + + assert!(!store.has_event_route("event1").await); + assert!(!store.has_event_route("event2").await); + assert!(store.has_event_route("event3").await); + } + + #[tokio::test] + async fn returns_zero_for_unknown_client() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.remove_for_client("unknown").await, 0); + } + + #[tokio::test] + async fn cleans_up_progress_tokens_for_removed_routes() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + store.remove_for_client("client1").await; + + assert!(!store.has_progress_token("token1").await); + assert!(!store.has_progress_token("token2").await); + } + + #[tokio::test] + async fn removes_client_from_index_after_cleanup() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + store.remove_for_client("client1").await; + + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── hasEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn has_event_route_true_for_existing() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + assert!(store.has_event_route("event1").await); + } + + #[tokio::test] + async fn has_event_route_false_for_unknown() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_event_route("unknown").await); + } + + // ── hasProgressToken ────────────────────────────────────────── + + #[tokio::test] + async fn has_progress_token_true_for_existing() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + assert!(store.has_progress_token("token1").await); + } + + #[tokio::test] + async fn has_progress_token_false_for_unknown() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_progress_token("unknown").await); + } + + // ── hasActiveRoutesForClient ────────────────────────────────── + + #[tokio::test] + async fn has_active_routes_true_when_routes_exist() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + assert!(store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn has_active_routes_false_when_no_routes() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn has_active_routes_false_after_all_popped() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store.pop("event1").await; + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── eventRouteCount ─────────────────────────────────────────── + + #[tokio::test] + async fn event_route_count_zero_for_empty() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.event_route_count().await, 0); + } + + #[tokio::test] + async fn event_route_count_after_registrations() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + assert_eq!(store.event_route_count().await, 2); + } + + #[tokio::test] + async fn event_route_count_after_removals() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store.pop("event1").await; + assert_eq!(store.event_route_count().await, 1); + } + + // ── progressTokenCount ──────────────────────────────────────── + + #[tokio::test] + async fn progress_token_count_zero_for_empty() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.progress_token_count().await, 0); + } + + #[tokio::test] + async fn progress_token_count_after_registrations() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + store + .register("event3".into(), "client1".into(), json!("req3"), None) + .await; + assert_eq!(store.progress_token_count().await, 2); + } + + #[tokio::test] + async fn progress_token_count_after_removals() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + store.pop("event1").await; + assert_eq!(store.progress_token_count().await, 1); + } + + // ── clear ───────────────────────────────────────────────────── + + #[tokio::test] + async fn clear_removes_all_routes() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client2".into(), json!("req2"), None) + .await; + + store.clear().await; + + assert_eq!(store.event_route_count().await, 0); + assert!(!store.has_event_route("event1").await); + } + + #[tokio::test] + async fn clear_removes_all_progress_tokens() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + store.clear().await; + + assert_eq!(store.progress_token_count().await, 0); + assert!(!store.has_progress_token("token1").await); + } + + #[tokio::test] + async fn clear_cleans_up_client_index() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + store.clear().await; + + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── complex scenarios ───────────────────────────────────────── + + #[tokio::test] + async fn handles_multiple_clients_with_multiple_routes() { + let store = ServerEventRouteStore::new(); + + // Client 1: 2 routes + store + .register( + "c1e1".into(), + "client1".into(), + json!("r1"), + Some("t1".into()), + ) + .await; + store + .register( + "c1e2".into(), + "client1".into(), + json!("r2"), + Some("t2".into()), + ) + .await; + + // Client 2: 1 route + store + .register( + "c2e1".into(), + "client2".into(), + json!("r3"), + Some("t3".into()), + ) + .await; + + assert_eq!(store.event_route_count().await, 3); + assert_eq!(store.progress_token_count().await, 3); + assert!(store.has_active_routes_for_client("client1").await); + assert!(store.has_active_routes_for_client("client2").await); + + // Remove one of client1's routes + store.pop("c1e1").await; + + assert!(store.has_active_routes_for_client("client1").await); + assert!(!store.has_progress_token("t1").await); + assert!(store.has_progress_token("t2").await); + } + + #[tokio::test] + async fn handles_route_replacement_with_same_progress_token() { + let store = ServerEventRouteStore::new(); + + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + + // Register new route with same token (overwrites mapping) + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token1".into()), + ) + .await; + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event2") + ); + } + + #[tokio::test] + async fn maintains_consistency_through_mixed_operations() { + let store = ServerEventRouteStore::new(); + + // Add routes + store + .register("e1".into(), "c1".into(), json!("r1"), Some("t1".into())) + .await; + store + .register("e2".into(), "c1".into(), json!("r2"), Some("t2".into())) + .await; + store + .register("e3".into(), "c2".into(), json!("r3"), Some("t3".into())) + .await; + + // Remove one + store.pop("e2").await; + + // Verify consistency + assert!(store.has_event_route("e1").await); + assert!(!store.has_event_route("e2").await); + assert!(store.has_event_route("e3").await); + + assert!(store.has_progress_token("t1").await); + assert!(!store.has_progress_token("t2").await); + assert!(store.has_progress_token("t3").await); + + assert!(store.has_active_routes_for_client("c1").await); + assert!(store.has_active_routes_for_client("c2").await); + } + + // ── LRU eviction (TS SDK server tests 28–30) ───────────────── + + #[tokio::test] + async fn evicts_oldest_route_when_capacity_reached() { + let store = ServerEventRouteStore::with_max_routes(2); + + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store + .register("event3".into(), "client1".into(), json!("req3"), None) + .await; + + // event1 should have been evicted. + assert!(!store.has_event_route("event1").await); + assert_eq!(store.event_route_count().await, 2); + } + + #[tokio::test] + async fn cleans_up_progress_tokens_on_eviction() { + let store = ServerEventRouteStore::with_max_routes(1); + + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + assert!(!store.has_progress_token("token1").await); + assert!(store.has_progress_token("token2").await); + } + + #[tokio::test] + async fn cleans_up_client_index_on_eviction() { + let store = ServerEventRouteStore::with_max_routes(1); + + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client2".into(), json!("req2"), None) + .await; + + // client1's only route was evicted. + assert!(!store.has_active_routes_for_client("client1").await); + assert!(store.has_active_routes_for_client("client2").await); + } +} diff --git a/tests/conformance_wire_format.rs b/tests/conformance_wire_format.rs new file mode 100644 index 0000000..828ac38 --- /dev/null +++ b/tests/conformance_wire_format.rs @@ -0,0 +1,463 @@ +//! Conformance tests for ContextVM wire format: MCP JSON-RPC carried in Nostr kind 25910 events. +//! +//! These mirror the layering style of `src/rmcp_transport/pipeline_tests.rs`: build the JSON-RPC +//! payload, serialize through the same helpers the transport uses (`mcp_to_nostr_event`, tag +//! builders from [`BaseTransport`]), sign with nostr-sdk, then assert on kind, tags, and content. + +use contextvm_sdk::core::constants::{ + mcp_protocol_version, tags, CTXVM_MESSAGES_KIND, INITIALIZE_METHOD, + NOTIFICATIONS_INITIALIZED_METHOD, SERVER_ANNOUNCEMENT_KIND, +}; +use contextvm_sdk::core::serializers; +use contextvm_sdk::core::types::{ + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, +}; +use contextvm_sdk::transport::base::BaseTransport; +use nostr_sdk::prelude::*; + +fn assert_ctxvm_message_kind(event: &Event) { + assert_eq!( + event.kind, + Kind::Custom(CTXVM_MESSAGES_KIND), + "ContextVM MCP messages must use kind {}", + CTXVM_MESSAGES_KIND + ); +} + +fn p_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::PUBKEY) +} + +fn e_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::EVENT_ID) +} + +// ── Initialize request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_request_has_kind_p_tag_and_jsonrpc_initialize() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("initialize request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign initialize request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialize request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some(INITIALIZE_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(1)); +} + +// ── Initialize response ────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_response_has_kind_e_tag_and_result_protocol_version() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + // Signed request event provides the Nostr event id referenced by e on the response. + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({})), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = + serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("request event for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign request event for correlation"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { + "name": "conformance-test-server", + "version": "0.0.0" + }, + "capabilities": {} + }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&init_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("initialize response should serialize") + .sign_with_keys(&server_keys) + .expect("sign initialize response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "initialize response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "initialize response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-1")); + assert!(v["result"]["protocolVersion"].is_string()); + assert!(v["result"]["serverInfo"]["name"].is_string()); +} + +// ── notifications/initialized ────────────────────────────────────────────── + +#[test] +fn ctxvm_notifications_initialized_has_kind_p_tag_and_method() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: NOTIFICATIONS_INITIALIZED_METHOD.to_string(), + params: None, + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let event = serializers::mcp_to_nostr_event(¬if, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("notification should serialize") + // Client sends this to the server; signer must differ from `p` so the tag is not stripped. + .sign_with_keys(&client_keys) + .expect("sign initialized notification"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialized notification must include server p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content).expect("parse content"); + assert!(msg.is_notification()); + assert_eq!(msg.method(), Some(NOTIFICATIONS_INITIALIZED_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert!( + v.get("id").is_none_or(serde_json::Value::is_null), + "JSON-RPC notifications must not include an id" + ); +} + +// ── tools/list request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_list_request_has_kind_p_tag_and_method() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let list_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&list_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/list request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign tools/list request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "tools/list request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/list")); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(2)); +} + +// ── tools/call request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_call_request_has_kind_p_tag_method_and_params() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let call_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(3), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "add", + "arguments": { "a": 5, "b": 3 } + })), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&call_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/call request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign tools/call request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "tools/call request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/call")); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(3)); + assert_eq!(v["params"]["name"], "add"); + assert!( + v["params"]["arguments"].is_object(), + "tools/call params.arguments must be an object on the wire" + ); +} + +// ── tools/list response ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_list_response_has_kind_e_tag_and_result() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + let list_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-list"), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = + serializers::mcp_to_nostr_event(&list_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/list request for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign tools/list request event for correlation"); + + let list_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-list"), + result: serde_json::json!({ "tools": [] }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&list_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("tools/list response should serialize") + .sign_with_keys(&server_keys) + .expect("sign tools/list response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "tools/list response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "tools/list response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-tools-list")); + assert!(v["result"]["tools"].is_array()); +} + +// ── tools/call response ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_call_response_has_kind_e_tag_and_result() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + let call_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-call"), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "add", + "arguments": { "a": 5, "b": 3 } + })), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = + serializers::mcp_to_nostr_event(&call_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/call request for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign tools/call request event for correlation"); + + let call_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-call"), + result: serde_json::json!({ + "content": [{ "type": "text", "text": "8" }], + "isError": false + }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&call_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("tools/call response should serialize") + .sign_with_keys(&server_keys) + .expect("sign tools/call response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "tools/call response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "tools/call response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-tools-call")); + assert!(v["result"]["content"].is_array()); + assert_eq!(v["result"]["isError"], serde_json::json!(false)); +} + +// ── Server announcement (kind 11316) ────────────────────────────────────────── + +#[test] +fn ctxvm_server_announcement_has_kind_and_required_tags() { + let server_keys = Keys::generate(); + + // MCP-flavoured JSON for wire conformance; not the same content shape as `NostrServerTransport::announce` (flat `ServerInfo` only). + let content = serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "Test Server" }, + "capabilities": {}, + }); + let content_str = serde_json::to_string(&content).expect("announcement content must serialize"); + + let announcement_tags = vec![ + Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec!["Test Server".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::ABOUT.into()), + vec!["A test server".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::WEBSITE.into()), + vec!["http://localhost".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + ), + ]; + + let event = EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content_str) + .tags(announcement_tags) + .sign_with_keys(&server_keys) + .expect("sign server announcement event"); + + assert_eq!( + event.kind, + Kind::Custom(SERVER_ANNOUNCEMENT_KIND), + "server announcement must use kind {}", + SERVER_ANNOUNCEMENT_KIND + ); + assert_eq!(event.pubkey, server_keys.public_key()); + + assert_eq!( + serializers::get_tag_value(&event.tags, tags::NAME).as_deref(), + Some("Test Server") + ); + assert_eq!( + serializers::get_tag_value(&event.tags, tags::ABOUT).as_deref(), + Some("A test server") + ); + assert_eq!( + serializers::get_tag_value(&event.tags, tags::WEBSITE).as_deref(), + Some("http://localhost") + ); + + assert!( + event.tags.iter().any(|t| { + let parts = t.clone().to_vec(); + parts.len() == 1 && parts.first().map(|s| s.as_str()) == Some(tags::SUPPORT_ENCRYPTION) + }), + "support_encryption must be present as a single-element tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("announcement content must be JSON"); + assert_eq!(v["protocolVersion"], mcp_protocol_version()); + assert_eq!(v["serverInfo"]["name"], "Test Server"); + assert!(v["capabilities"].is_object()); +} diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..339f2a0 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,195 @@ +//! Local RMCP integration test (in-process duplex I/O, no relay required). +//! Relay-dependent scenarios live in `examples/rmcp_integration_test.rs` +//! and run via the `integration.yml` workflow against a local relay container. + +#![cfg(feature = "rmcp")] + +use rmcp::{ + handler::server::router::tool::ToolRouter, handler::server::wrapper::Parameters, model::*, + schemars, service::RequestContext, tool, tool_handler, tool_router, ClientHandler, RoleServer, + ServerHandler, ServiceExt, +}; +use std::sync::Arc; +use tokio::sync::Mutex; + +// Minimal fixture: same tools as examples/rmcp_integration_test.rs + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct AddParams { + a: i64, + b: i64, +} + +#[derive(Clone)] +struct DemoServer { + echo_count: Arc>, + tool_router: ToolRouter, +} + +impl DemoServer { + fn new() -> Self { + Self { + echo_count: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl DemoServer { + #[tool(description = "Echo a message back")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + let mut n = self.echo_count.lock().await; + *n += 1; + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo #{n}: {message}" + ))])) + } + + #[tool(description = "Add two integers")] + fn add( + &self, + Parameters(AddParams { a, b }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "{a} + {b} = {}", + a + b + ))])) + } + + #[tool(description = "Return total echo calls")] + async fn get_echo_count(&self) -> Result { + let n = self.echo_count.lock().await; + Ok(CallToolResult::success(vec![Content::text(format!( + "Total echo calls: {n}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for DemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + server_info: Implementation { + name: "integration-test".to_string(), + title: None, + version: "0.1.0".to_string(), + description: None, + icons: None, + website_url: None, + }, + instructions: None, + } + } + + async fn list_resources( + &self, + _req: Option, + _ctx: RequestContext, + ) -> Result { + Ok(ListResourcesResult { + resources: vec![ + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation() + ], + next_cursor: None, + meta: None, + }) + } + + async fn read_resource( + &self, + req: ReadResourceRequestParams, + _ctx: RequestContext, + ) -> Result { + match req.uri.as_str() { + "demo://readme" => Ok(ReadResourceResult { + contents: vec![ResourceContents::text("Demo content.", req.uri)], + }), + other => Err(ErrorData::resource_not_found( + "not_found", + Some(serde_json::json!({ "uri": other })), + )), + } + } +} + +#[derive(Clone, Default)] +struct DemoClient; +impl ClientHandler for DemoClient {} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|c| match &c.raw { + RawContent::Text(t) => Some(t.text.clone()), + _ => None, + }) + .unwrap_or_default() +} + +// ── Test ───────────────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_local_rmcp() { + let (server_io, client_io) = tokio::io::duplex(65536); + + let server_handle = tokio::spawn(async move { + DemoServer::new() + .serve(server_io) + .await + .expect("serve") + .waiting() + .await + .expect("server error"); + }); + + let client = DemoClient.serve(client_io).await.expect("client init"); + + let tools = client.list_all_tools().await.expect("list tools"); + assert_eq!(tools.len(), 3); + + let add = client + .call_tool(CallToolRequestParams { + name: "add".into(), + arguments: serde_json::from_value(serde_json::json!({ "a": 7, "b": 5 })).ok(), + meta: None, + task: None, + }) + .await + .expect("call add"); + assert!(first_text(&add).contains("12")); + + let resources = client.list_all_resources().await.expect("list resources"); + assert_eq!(resources.len(), 1); + + match client + .call_tool(CallToolRequestParams { + name: "no_such_tool".into(), + arguments: None, + meta: None, + task: None, + }) + .await + { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => panic!("expected unknown tool to fail"), + } + + client.cancel().await.expect("cancel"); + server_handle.abort(); +} diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs new file mode 100644 index 0000000..7bcaf1a --- /dev/null +++ b/tests/transport_integration.rs @@ -0,0 +1,3379 @@ +//! Integration tests — transport-level flows using MockRelayPool. +//! +//! Each test wires client and/or server transports to an in-memory mock relay +//! network so that the full event-loop logic (subscription, publish, routing, +//! encryption-mode enforcement, and authorization) is exercised without +//! connecting to real relays. + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::time::Duration; + +use async_trait::async_trait; +use contextvm_sdk::core::constants::tags; +use contextvm_sdk::core::constants::{ + mcp_protocol_version, CTXVM_MESSAGES_KIND, EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND, + PROMPTS_LIST_KIND, RESOURCES_LIST_KIND, RESOURCETEMPLATES_LIST_KIND, SERVER_ANNOUNCEMENT_KIND, + TOOLS_LIST_KIND, +}; +use contextvm_sdk::core::types::{EncryptionMode, GiftWrapMode}; +use contextvm_sdk::relay::mock::MockRelayPool; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use contextvm_sdk::{ + CapabilityExclusion, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + RelayPoolTrait, ServerInfo, +}; +use nostr_sdk::prelude::*; + +fn as_pool(pool: MockRelayPool) -> Arc { + Arc::new(pool) +} + +struct TestRelayPool { + inner: Arc, + publish_delay: Duration, + failures_remaining: AtomicUsize, + publish_attempts: AtomicUsize, +} + +impl TestRelayPool { + fn with_publish_delay(inner: Arc, publish_delay: Duration) -> Self { + Self { + inner, + publish_delay, + failures_remaining: AtomicUsize::new(0), + publish_attempts: AtomicUsize::new(0), + } + } + + fn with_publish_failures(inner: Arc, failures: usize) -> Self { + Self { + inner, + publish_delay: Duration::ZERO, + failures_remaining: AtomicUsize::new(failures), + publish_attempts: AtomicUsize::new(0), + } + } + + fn publish_attempts(&self) -> usize { + self.publish_attempts.load(Ordering::SeqCst) + } +} + +#[async_trait] +impl RelayPoolTrait for TestRelayPool { + async fn connect(&self, relay_urls: &[String]) -> contextvm_sdk::Result<()> { + self.inner.connect(relay_urls).await + } + + async fn disconnect(&self) -> contextvm_sdk::Result<()> { + self.inner.disconnect().await + } + + async fn publish_event(&self, event: &Event) -> contextvm_sdk::Result { + if !self.publish_delay.is_zero() { + tokio::time::sleep(self.publish_delay).await; + } + self.publish_attempts.fetch_add(1, Ordering::SeqCst); + let should_fail = self + .failures_remaining + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |remaining| { + remaining.checked_sub(1) + }) + .is_ok(); + + if should_fail { + return Err(contextvm_sdk::Error::Transport( + "injected publish failure".to_string(), + )); + } + + self.inner.publish_event(event).await + } + + async fn publish(&self, builder: EventBuilder) -> contextvm_sdk::Result { + if !self.publish_delay.is_zero() { + tokio::time::sleep(self.publish_delay).await; + } + self.inner.publish(builder).await + } + + async fn sign(&self, builder: EventBuilder) -> contextvm_sdk::Result { + self.inner.sign(builder).await + } + + async fn signer(&self) -> contextvm_sdk::Result> { + self.inner.signer().await + } + + fn notifications(&self) -> tokio::sync::broadcast::Receiver { + self.inner.notifications() + } + + async fn public_key(&self) -> contextvm_sdk::Result { + self.inner.public_key().await + } + + async fn subscribe(&self, filters: Vec) -> contextvm_sdk::Result<()> { + self.inner.subscribe(filters).await + } +} + +/// Let spawned event loops call `notifications()` before we publish anything. +/// Without this, broadcast messages can be lost on slow CI runners. +async fn let_event_loops_start() { + tokio::time::sleep(Duration::from_millis(10)).await; +} + +// ── 1. Full initialization handshake ──────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn full_initialization_handshake() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends initialize request. + let init_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "test-client", "version": "0.0.0" } + })), + }); + client + .send(&init_request) + .await + .expect("client send initialize"); + + // Server should receive the initialize request. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive init request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("initialize"), + "server must receive initialize request" + ); + + // Server sends initialize response. + let init_response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test-server", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_response) + .await + .expect("server send response"); + + // Client should receive the initialize response. + let response = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive init response") + .expect("client channel closed"); + + assert!(response.is_response(), "client must receive a response"); + assert_eq!(response.id(), Some(&serde_json::json!(1))); +} + +// ── 2. Server announcement publishing ─────────────────────────────────────── + +#[tokio::test] +async fn server_announcement_publishing() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("Phase3-Test-Server".to_string())) + .with_announced_server(true), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)); + + assert!( + announcement.is_some(), + "kind {} event must be published after announce()", + SERVER_ANNOUNCEMENT_KIND + ); + + let ann = announcement.unwrap(); + let content: serde_json::Value = + serde_json::from_str(&ann.content).expect("announcement content must be JSON"); + assert_eq!( + content["name"], "Phase3-Test-Server", + "announcement content must include server name" + ); +} + +// ── 3. Encryption mode Optional accepts plaintext ─────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_mode_optional_accepts_plaintext() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server uses Optional — should accept both encrypted and plaintext. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client uses Disabled — sends plaintext kind 25910. + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("plain-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + // Server must receive and process the plaintext message. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive plaintext request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "Optional-mode server must accept plaintext kind 25910" + ); + assert!( + !incoming.is_encrypted, + "plaintext request must not be marked as encrypted" + ); +} + +// ── 4. Auth allowlist blocks disallowed pubkey ────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_allowlist_blocks_disallowed_pubkey() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server allows only `allowed_keys` — client_keys is NOT allowed. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a non-initialize request (those are always allowed). + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // The server should NOT forward the request (pubkey is disallowed). + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "disallowed pubkey request must not reach the server handler" + ); +} + +// ── 5. Encryption mode Required drops plaintext ───────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_mode_required_drops_plaintext() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server requires encryption — plaintext must be dropped. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client sends plaintext (Disabled mode). + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("drop-me"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + // Server must NOT receive the plaintext message. + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "Required-mode server must drop plaintext kind 25910 events" + ); +} + +// ── 6. Encrypted gift-wrap roundtrip ──────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encrypted_gift_wrap_roundtrip() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends encrypted request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + // Verify the published event is a gift-wrap (kind 1059). + let events = server_pool.stored_events().await; + assert!( + events + .iter() + .any(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)), + "client must publish a kind 1059 gift-wrap event" + ); + + // Server should decrypt and receive the request. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to decrypt gift-wrap request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!(incoming.is_encrypted, "message must be marked encrypted"); + + // Server sends an encrypted response back. + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-1"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("server send encrypted response"); + + // Client should decrypt and receive the response. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to decrypt gift-wrap response") + .expect("client channel closed"); + + assert!(client_msg.is_response()); + assert_eq!(client_msg.id(), Some(&serde_json::json!("enc-1"))); +} + +// ── 7. Gift-wrap dedup skips duplicate delivery ───────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn gift_wrap_dedup_skips_duplicate_delivery() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends a gift-wrapped request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("dedup-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server receives the first delivery. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for first delivery") + .expect("server channel closed"); + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!(incoming.is_encrypted); + + // Re-deliver the same gift-wrap event (simulates relay redelivery). + let events = server_pool.stored_events().await; + let gift_wrap = events + .iter() + .find(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .expect("gift-wrap event must exist") + .clone(); + server_pool + .publish_event(&gift_wrap) + .await + .expect("re-inject duplicate"); + + // Server must NOT process the duplicate. + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "duplicate gift-wrap (same outer event id) must be skipped" + ); +} + +// ── 8. Correlated notification has e tag ───────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn correlated_notification_has_e_tag() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends a tools/list request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("notif-corr"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server receives the request and captures the event_id. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + assert_eq!(incoming.message.method(), Some("tools/list")); + let request_event_id = incoming.event_id.clone(); + + // Server sends a correlated notifications/progress notification. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ + "progressToken": "tok-1", + "progress": 50, + "total": 100 + })), + }); + server + .send_notification( + &incoming.client_pubkey, + ¬ification, + Some(&request_event_id), + ) + .await + .expect("send correlated notification"); + + // Client should receive the notification. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive notification") + .expect("client channel closed"); + + assert!(client_msg.is_notification()); + assert_eq!(client_msg.method(), Some("notifications/progress")); + + // The published notification event must carry an e tag referencing the request. + let events = server_pool.stored_events().await; + let notif_event = events + .iter() + .find(|e| e.pubkey == server_pubkey && e.content.contains("notifications/progress")) + .expect("notification event must be in stored events"); + + let e_tag = contextvm_sdk::core::serializers::get_tag_value(¬if_event.tags, "e"); + assert_eq!( + e_tag.as_deref(), + Some(request_event_id.as_str()), + "notification event must have e tag referencing the original request event id" + ); +} + +// ── 9. Encryption Required client, Optional server ────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_required_client_optional_server() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-opt-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive encrypted request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "Optional-mode server must accept encrypted messages from Required-mode client" + ); + assert!( + incoming.is_encrypted, + "message from Required-mode client must be marked encrypted" + ); +} + +// ── 10. Encryption Optional both sides, encrypted path ────────────────────── +// Optional client defaults to encrypting (unwrap_or(true)), Optional server +// accepts encrypted messages. Tests the Optional/Optional negotiation path. + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_optional_both_sides_encrypted_path() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("opt-both-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!( + incoming.is_encrypted, + "Optional client defaults to encrypting; Optional server must accept" + ); +} + +// ── 11. Announce includes encryption tags ──────────────────────────────────── + +#[tokio::test] +async fn announce_includes_encryption_tags() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Required) + .with_server_info(ServerInfo::default().with_name("Encrypted-Server".to_string())) + .with_announced_server(true), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)) + .expect("kind 11316 event must be published"); + + // support_encryption is a valueless tag — check tag name directly. + let has_support_encryption = announcement + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + let has_support_encryption_ephemeral = announcement.tags.iter().any(|t| { + t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption_ephemeral") + }); + + assert!( + has_support_encryption, + "announcement must include support_encryption tag" + ); + assert!( + has_support_encryption_ephemeral, + "announcement must include support_encryption_ephemeral tag" + ); +} + +// ── 12. Announce includes server metadata tags ────────────────────────────── + +#[tokio::test] +async fn announce_includes_server_metadata_tags() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info( + ServerInfo::default() + .with_name("Meta-Server".to_string()) + .with_about("A test server".to_string()) + .with_website("https://example.com".to_string()) + .with_picture("https://example.com/pic.png".to_string()), + ) + .with_announced_server(true), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)) + .expect("kind 11316 event must be published"); + + let name_tag = contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "name"); + let about_tag = contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "about"); + let website_tag = + contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "website"); + let picture_tag = + contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "picture"); + + assert_eq!( + name_tag.as_deref(), + Some("Meta-Server"), + "name tag must be present" + ); + assert_eq!( + about_tag.as_deref(), + Some("A test server"), + "about tag must be present" + ); + assert_eq!( + website_tag.as_deref(), + Some("https://example.com"), + "website tag must be present" + ); + assert_eq!( + picture_tag.as_deref(), + Some("https://example.com/pic.png"), + "picture tag must be present" + ); +} + +// ── 13. Publish tools produces correct kind ───────────────────────────────── + +#[tokio::test] +async fn publish_tools_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("Tools-Server".to_string())) + .with_announced_server(true), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let tools = vec![serde_json::json!({ + "name": "get_weather", + "description": "Get the weather", + "inputSchema": { "type": "object" } + })]; + server.publish_tools(tools).await.expect("publish tools"); + + let events = pool.stored_events().await; + let tools_event = events + .iter() + .find(|e| e.kind == Kind::Custom(TOOLS_LIST_KIND)) + .expect("kind 11317 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&tools_event.content).expect("tools content must be JSON"); + assert!( + content.get("tools").is_some(), + "tools event content must contain 'tools' key" + ); + let tools_arr = content["tools"].as_array().expect("tools must be an array"); + assert_eq!(tools_arr.len(), 1); + assert_eq!(tools_arr[0]["name"], "get_weather"); +} + +// ── 14. Broadcast notification reaches initialized client ───────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn broadcast_notification_reaches_initialized_client() { + let (c1_pool, s_pool) = MockRelayPool::create_pair(); + let server_pk = s_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(s_pool), + ) + .await + .expect("create server transport"); + + let mut srv_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pk.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(c1_pool), + ) + .await + .expect("create client transport"); + let mut c_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends initialize request. + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "c1", "version": "0.0.0" } + })), + }); + client + .send(&init_req) + .await + .expect("client send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), srv_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Server responds to initialize. + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test-server", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + // Client receives the init response. + let _ = tokio::time::timeout(Duration::from_millis(500), c_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Client sends notifications/initialized → session becomes initialized. + let init_notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + client + .send(&init_notif) + .await + .expect("send initialized notification"); + + // Drain srv_rx until we see notifications/initialized (skipping any + // echoed events from the shared mock relay broadcast channel). + loop { + let msg = tokio::time::timeout(Duration::from_millis(500), srv_rx.recv()) + .await + .expect("timeout waiting for notifications/initialized on server") + .expect("server channel closed"); + if msg.message.method() == Some("notifications/initialized") { + break; + } + } + + // Now broadcast — only the initialized client session should receive it. + let broadcast = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "bc-1", "progress": 1, "total": 1 })), + }); + server + .broadcast_notification(&broadcast) + .await + .expect("broadcast notification"); + + let msg = tokio::time::timeout(Duration::from_millis(500), c_rx.recv()) + .await + .expect("timeout waiting for client to receive broadcast") + .expect("client channel closed"); + + assert_eq!(msg.method(), Some("notifications/progress")); +} + +// ── 15. Uncorrelated notification passes through ──────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn uncorrelated_notification_passes_through() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("unc-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "unc-test", "version": "0.0.0" } + })), + }); + client.send(&init_req).await.expect("send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("unc-init"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Uncorrelated notification (no e tag) must pass through to client. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "unc-1", "progress": 50, "total": 100 })), + }); + server + .send_notification(&incoming.client_pubkey, ¬ification, None) + .await + .expect("send uncorrelated notification"); + + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive notification") + .expect("client channel closed"); + + assert!(client_msg.is_notification()); + assert_eq!(client_msg.method(), Some("notifications/progress")); +} + +// ── 16. Correlated notification with unknown e tag is dropped ─────────────── +// NOTE: The Rust SDK drops ANY server event whose e-tag references an unknown +// pending request, including notifications. The TS SDK may forward such events. +// This test documents the Rust SDK's stricter correlation enforcement. + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn correlated_notification_unknown_e_tag_is_dropped() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "corr-test", "version": "0.0.0" } + })), + }); + client.send(&init_req).await.expect("send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-init"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Notification with e tag referencing unknown event id must be dropped. + let fake_event_id = "a".repeat(64); + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "fake", "progress": 1, "total": 1 })), + }); + server + .send_notification(&incoming.client_pubkey, ¬ification, Some(&fake_event_id)) + .await + .expect("send notification with unknown e tag"); + + let result = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()).await; + assert!( + result.is_err(), + "notification with unknown e tag must be dropped by client" + ); +} + +// ── 17. Auth: allowed pubkey receives response ────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_allowed_pubkey_receives_response() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let client_pubkey = client_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![client_pubkey.to_hex()]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("auth-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server should receive it (pubkey is in the allowlist). + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + + // Server sends response back. + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("auth-1"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send response"); + + // Client should receive the response. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive response") + .expect("client channel closed"); + + assert!(client_msg.is_response()); + assert_eq!(client_msg.id(), Some(&serde_json::json!("auth-1"))); +} + +// ── 18. Excluded capability bypasses auth ─────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn excluded_capability_bypasses_auth() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey, NOT the client + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]) + .with_excluded_capabilities(vec![CapabilityExclusion { + method: "tools/list".to_string(), + name: None, + }]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client's pubkey is NOT in the allowlist, but "tools/list" is excluded from auth. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("excl-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server should receive it because the method is in excluded_capabilities. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive excluded-capability request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "excluded capability must bypass auth allowlist" + ); +} + +// ── 19. Publish resources produces correct kind ───────────────────────────── + +#[tokio::test] +async fn publish_resources_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let resources = vec![serde_json::json!({ + "uri": "file:///readme.md", + "name": "readme", + "mimeType": "text/markdown" + })]; + server + .publish_resources(resources) + .await + .expect("publish resources"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(RESOURCES_LIST_KIND)) + .expect("kind 11318 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["resources"] + .as_array() + .expect("resources must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "readme"); +} + +// ── 20. Publish prompts produces correct kind ─────────────────────────────── + +#[tokio::test] +async fn publish_prompts_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let prompts = vec![serde_json::json!({ + "name": "summarize", + "description": "Summarize text" + })]; + server + .publish_prompts(prompts) + .await + .expect("publish prompts"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(PROMPTS_LIST_KIND)) + .expect("kind 11320 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["prompts"] + .as_array() + .expect("prompts must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "summarize"); +} + +// ── 21. Publish resource templates produces correct kind ──────────────────── + +#[tokio::test] +async fn publish_resource_templates_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let templates = vec![serde_json::json!({ + "uriTemplate": "file:///{path}", + "name": "file", + "mimeType": "application/octet-stream" + })]; + server + .publish_resource_templates(templates) + .await + .expect("publish resource templates"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(RESOURCETEMPLATES_LIST_KIND)) + .expect("kind 11319 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["resourceTemplates"] + .as_array() + .expect("resourceTemplates must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "file"); +} + +// ── 22. Publish tools with empty list ─────────────────────────────────────── + +#[tokio::test] +async fn publish_tools_empty_list() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server + .publish_tools(vec![]) + .await + .expect("publish empty tools"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(TOOLS_LIST_KIND)) + .expect("kind 11317 event must be published for empty list"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["tools"].as_array().expect("tools must be an array"); + assert!(arr.is_empty(), "empty tools list must produce tools: []"); +} + +// ── 23. Delete announcements k tags match kinds ───────────────────────────── + +#[tokio::test] +async fn delete_announcements_k_tags_match_kinds() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("KTag-Server".to_string())) + .with_announced_server(true), + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + server + .delete_announcements("shutting down") + .await + .expect("delete announcements"); + + let events = pool.stored_events().await; + let kind5_events: Vec<_> = events + .iter() + .filter(|e| e.kind == Kind::Custom(5)) + .collect(); + + assert_eq!(kind5_events.len(), 5); + + // Collect k tag values from all kind-5 events. + let mut k_values: Vec = kind5_events + .iter() + .filter_map(|e| { + contextvm_sdk::core::serializers::get_tag_value(&e.tags, "k") + .and_then(|v| v.parse::().ok()) + }) + .collect(); + k_values.sort(); + + let mut expected = vec![ + SERVER_ANNOUNCEMENT_KIND, + TOOLS_LIST_KIND, + RESOURCES_LIST_KIND, + RESOURCETEMPLATES_LIST_KIND, + PROMPTS_LIST_KIND, + ]; + expected.sort(); + + assert_eq!( + k_values, expected, + "each kind-5 event must have a k tag matching one announcement kind" + ); +} + +// ── 24. Encryption Disabled server rejects gift-wrap ──────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_disabled_server_rejects_gift_wrap() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server has encryption disabled — must reject gift-wrap events. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client requires encryption — sends gift-wrap (kind 1059). + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("gw-reject"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "Disabled-mode server must drop gift-wrap events" + ); +} + +// ── 25. Response mirrors client encryption format ─────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn response_mirrors_client_encryption_format() { + // Part A: Disabled client → Optional server → response must be plaintext (kind 25910). + { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-plain"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + assert!(!incoming.is_encrypted); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-plain"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send plaintext response"); + + // Client receives the response. + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Verify response event is plaintext kind 25910, not gift-wrap. + let events = server_pool.stored_events().await; + let response_events: Vec<_> = events + .iter() + .filter(|e| e.pubkey == server_pubkey && e.content.contains("mirror-plain")) + .collect(); + assert!( + !response_events.is_empty(), + "server must publish a response event" + ); + assert!( + response_events + .iter() + .all(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)), + "response to plaintext client must be kind {} (plaintext)", + CTXVM_MESSAGES_KIND + ); + } + + // Part B: Required client → Optional server → response must be gift-wrap (kind 1059). + { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-enc"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + assert!(incoming.is_encrypted); + + // Snapshot gift-wrap count before server responds. + let gw_before = server_pool + .stored_events() + .await + .iter() + .filter(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .count(); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-enc"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send encrypted response"); + + // Client receives the response. + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Verify server published exactly one new gift-wrap for the response. + let gw_after = server_pool + .stored_events() + .await + .iter() + .filter(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .count(); + assert_eq!( + gw_after, + gw_before + 1, + "server must publish one new gift-wrap (kind {}) as the response", + GIFT_WRAP_KIND + ); + } +} + +// ── 26. send_response is one-shot under concurrency ──────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn send_response_is_one_shot_under_concurrency() { + let (client_pool, server_pool_raw) = MockRelayPool::create_pair(); + let server_pubkey = server_pool_raw.mock_public_key(); + let server_pool = Arc::new(server_pool_raw); + + // Delay publish so both concurrent responders have a chance to race. + // Correct behavior is still one-shot: exactly one send_response succeeds. + let delayed_server_pool: Arc = Arc::new(TestRelayPool::with_publish_delay( + Arc::clone(&server_pool), + Duration::from_millis(25), + )); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + delayed_server_pool, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("one-shot-req"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("placeholder"), + result: serde_json::json!({ "one_shot": "ok" }), + }); + + let event_id = incoming.event_id.clone(); + let f1 = server.send_response(&event_id, response.clone()); + let f2 = server.send_response(&event_id, response); + let (r1, r2) = tokio::join!(f1, f2); + + assert_ne!( + r1.is_ok(), + r2.is_ok(), + "exactly one concurrent send_response call must succeed" + ); + + let msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive response") + .expect("client channel closed"); + assert!(msg.is_response(), "client must receive one response"); + assert_eq!( + msg.id(), + Some(&serde_json::json!("one-shot-req")), + "server must restore original request id in response" + ); + + let second = tokio::time::timeout(Duration::from_millis(200), client_rx.recv()).await; + assert!( + second.is_err(), + "client must not receive duplicate response" + ); + + let events = server_pool.stored_events().await; + let response_events = events + .iter() + .filter(|e| e.pubkey == server_pubkey && e.content.contains("\"one_shot\":\"ok\"")) + .count(); + assert_eq!( + response_events, 1, + "only one response event must be published" + ); +} + +// ── 27. send_response publish failure allows retry ───────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn send_response_publish_failure_allows_one_successful_retry() { + let (client_pool, server_pool_raw) = MockRelayPool::create_pair(); + let server_pubkey = server_pool_raw.mock_public_key(); + let server_pool = Arc::new(server_pool_raw); + let failing_server_pool = Arc::new(TestRelayPool::with_publish_failures( + Arc::clone(&server_pool), + 1, + )); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&failing_server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("retry-once"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server request") + .expect("server channel closed"); + assert_eq!(incoming.message.method(), Some("tools/list")); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("placeholder"), + result: serde_json::json!({ "tools": [] }), + }); + + let stored_before_failure = server_pool.stored_events().await.len(); + server + .send_response(&incoming.event_id, response.clone()) + .await + .expect_err("first response publish must fail"); + + assert_eq!( + failing_server_pool.publish_attempts(), + 1, + "failed response should attempt exactly one publish" + ); + assert_eq!( + server_pool.stored_events().await.len(), + stored_before_failure, + "failed publish must not store a response event" + ); + + server + .send_response(&incoming.event_id, response.clone()) + .await + .expect("retry must still find the route and publish"); + + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for retried response") + .expect("client channel closed"); + assert!(client_msg.is_response()); + assert_eq!(client_msg.id(), Some(&serde_json::json!("retry-once"))); + assert_eq!( + failing_server_pool.publish_attempts(), + 2, + "retry should perform the second and final publish" + ); + assert_eq!( + server_pool.stored_events().await.len(), + stored_before_failure + 1, + "successful retry must publish exactly one response event" + ); + + server + .send_response(&incoming.event_id, response) + .await + .expect_err("route must be consumed after the successful retry"); + assert_eq!( + failing_server_pool.publish_attempts(), + 2, + "consumed route should fail before another publish attempt" + ); + assert_eq!( + server_pool.stored_events().await.len(), + stored_before_failure + 1, + "post-success retry must not publish another response" + ); + + let second_delivery = tokio::time::timeout(Duration::from_millis(50), client_rx.recv()).await; + assert!( + second_delivery.is_err(), + "client must receive the retried response exactly once" + ); +} + +// ── 28. Announced server sends unauthorized error response ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn announced_server_sends_unauthorized_error_response() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey — client is NOT in the allowlist + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Announced server with an allowlist that does NOT include the client. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_announced_server(true) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a non-initialize request from the unauthorized client. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // The server handler must NOT receive the request (it's unauthorized). + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized request must not reach the server handler" + ); + + // The client MUST receive a -32000 Unauthorized error response. + let error_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for unauthorized error response") + .expect("client channel closed"); + + match error_msg { + JsonRpcMessage::ErrorResponse(err) => { + assert_eq!(err.error.code, -32000, "error code must be -32000"); + assert_eq!( + err.error.message, "Unauthorized", + "error message must be 'Unauthorized'" + ); + } + other => panic!( + "expected ErrorResponse, got: {:?}", + std::mem::discriminant(&other) + ), + } +} + +// ── 29. Private server silently drops unauthorized request ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn private_server_silently_drops_unauthorized_request() { + let allowed_keys = Keys::generate(); + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Private server (is_announced_server defaults to false). + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(99), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server handler must not receive it. + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized request must not reach the server handler" + ); + + // Client must NOT receive any error response (private server silently drops). + let client_response = tokio::time::timeout(Duration::from_millis(300), client_rx.recv()).await; + assert!( + client_response.is_err(), + "private server must silently drop unauthorized requests without sending an error" + ); +} + +// ── 30. Announced server does not error on unauthorized notification ───────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn announced_server_does_not_error_on_unauthorized_notification() { + let allowed_keys = Keys::generate(); + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_announced_server(true) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a notification (not a request) from the unauthorized client. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: None, + }); + client.send(¬ification).await.expect("send notification"); + + // Server handler must not receive the notification. + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized notification must not reach the server handler" + ); + + // Client must NOT receive an error (notifications never get error replies). + let client_response = tokio::time::timeout(Duration::from_millis(300), client_rx.recv()).await; + assert!( + client_response.is_err(), + "announced server must not send error response for unauthorized notifications" + ); +} + +// ── 31. First response includes discovery tags (upstream CEP-19) ───────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn first_response_includes_discovery_tags() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let s_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("Disco-Server".to_string())) + .with_announced_server(true), + Arc::clone(&s_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send first request (triggers first response with common tags) + let request1 = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("req-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request1).await.expect("send request 1"); + + let incoming1 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let response1 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("req-1"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming1.event_id, response1) + .await + .expect("send response 1"); + + // Send second request (should NOT include common tags) + let request2 = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("req-2"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request2).await.expect("send request 2"); + + let incoming2 = loop { + let msg = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + if msg.message.is_request() && msg.message.id() == Some(&serde_json::json!("req-2")) { + break msg; + } + }; + + let response2 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("req-2"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming2.event_id, response2) + .await + .expect("send response 2"); + + let events = s_pool.stored_events().await; + let responses: Vec<_> = events + .iter() + .filter(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) + .cloned() + .collect(); + + let resp1 = responses + .iter() + .find(|e| e.content.contains("req-1") && e.content.contains("result")) + .expect("resp1 missing"); + let resp2 = responses + .iter() + .find(|e| e.content.contains("req-2") && e.content.contains("result")) + .expect("resp2 missing"); + + let name1 = contextvm_sdk::core::serializers::get_tag_value(&resp1.tags, "name"); + let enc1 = resp1 + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + + let name2 = contextvm_sdk::core::serializers::get_tag_value(&resp2.tags, "name"); + let enc2 = resp2 + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + + assert_eq!(name1.as_deref(), Some("Disco-Server")); + assert!(enc1); + + assert_eq!(name2, None); + assert!(!enc2); +} + +// ── 32. Notification mirror selection wrt CEP 19 ────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn notification_mirror_selection_wrt_cep_19() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let s_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::clone(&s_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Ephemeral), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request1 = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("req-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request1).await.expect("send request 1"); + + let incoming1 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: None, + }); + server + .send_notification( + &incoming1.client_pubkey, + ¬ification, + Some(&incoming1.event_id), + ) + .await + .expect("send notification"); + + let events = s_pool.stored_events().await; + let ephemeral_wraps: Vec<_> = events + .iter() + .filter(|e| e.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)) + .cloned() + .collect(); + + assert!( + ephemeral_wraps.len() >= 2, + "Expected ephemeral wraps for both request and notification" + ); +} + +// ── CEP-35: Server-side discovery tag emission & capability learning ───────── + +fn event_tag_vecs(event: &Event) -> Vec> { + event.tags.iter().map(|t| t.clone().to_vec()).collect() +} + +fn has_tag_name(event: &Event, name: &str) -> bool { + event_tag_vecs(event) + .iter() + .any(|v| v.first().map(|s| s.as_str()) == Some(name)) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_response_includes_encryption_tags_when_enabled() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .expect("response event must exist"); + + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "first response must include support_encryption when mode != Disabled" + ); + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "first response must include support_encryption_ephemeral when GiftWrapMode != Persistent" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_response_excludes_ephemeral_tag_when_persistent() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Persistent), + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .unwrap(); + + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "support_encryption must be present" + ); + assert!( + !has_tag_name(response_event, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "support_encryption_ephemeral must NOT be present when GiftWrapMode is Persistent" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_learns_capabilities_from_client_request() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + assert_eq!(incoming.message.method(), Some("initialize")); + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming2 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(incoming2.message.method(), Some("tools/list")); + assert_eq!(incoming.client_pubkey, incoming2.client_pubkey); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_disabled_encryption_omits_encryption_tags() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("NoEncrypt".to_string())), + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .unwrap(); + + assert!(has_tag_name(response_event, tags::NAME)); + assert!( + !has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "encryption tags must be omitted when EncryptionMode is Disabled" + ); + assert!(!has_tag_name( + response_event, + tags::SUPPORT_ENCRYPTION_EPHEMERAL + )); +} + +// ── CEP-35: Client-side discovery tag emission & capability learning ───────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_disabled_encryption_emits_no_discovery_tags() { + // Disabled encryption: client must not emit cap tags. Positive case (Optional + // mode emits tags) is covered by unit test client_capability_tags_encryption_optional. + let pool = Arc::new(MockRelayPool::new()); + let server_keys = Keys::generate(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::clone(&pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + // With Disabled encryption, no cap tags are emitted (correct per spec). + // Verify the event is published with p tag but without cap tags. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = pool.stored_events().await; + let client_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) + .expect("client must publish a request event"); + + // p tag must be present (routing) + assert!(has_tag_name(client_event, "p")); + // No encryption tags when Disabled (the unit test covers the Optional case) + assert!( + !has_tag_name(client_event, tags::SUPPORT_ENCRYPTION), + "Disabled client must not emit support_encryption" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_second_request_carries_no_discovery_tags() { + // Second request must never carry discovery tags. One-shot flag behavior + // is covered by unit test client_discovery_tags_sent_once. + let pool = Arc::new(MockRelayPool::new()); + let server_keys = Keys::generate(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::clone(&pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + // First request + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + // Second request + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = pool.stored_events().await; + let ctxvm_events: Vec<&Event> = events + .iter() + .filter(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) + .collect(); + assert!(ctxvm_events.len() >= 2); + + let second_event = ctxvm_events + .iter() + .find(|e| e.content.contains("tools/list")) + .expect("second request event must exist"); + + assert!( + !has_tag_name(second_event, tags::SUPPORT_ENCRYPTION), + "second request must NOT include discovery tags" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_learns_server_capabilities_from_first_response() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("CapServer".to_string())), + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + // Client should have learned capabilities from server's first response + let caps = client.discovered_server_capabilities(); + assert!( + caps.supports_encryption, + "client must learn support_encryption from server response tags" + ); + assert!( + caps.supports_ephemeral_encryption, + "client must learn support_encryption_ephemeral from server response tags" + ); + + let baseline = client.get_server_initialize_event(); + assert!(baseline.is_some(), "baseline event must be set"); +} + +// ── CEP-35: OR-assign, baseline-freeze, and Optional emission ──────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_or_assigns_capabilities_across_responses() { + // Server with Persistent gift-wrap emits support_encryption but NOT + // support_encryption_ephemeral on the first response. A second event + // carrying support_encryption_ephemeral must OR-assign into the client's + // learned caps without downgrading the already-learned support_encryption. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Persistent) + .with_server_info(ServerInfo::default().with_name("PersistentServer".to_string())), + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + // First roundtrip — server responds with support_encryption only. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let caps_after_first = client.discovered_server_capabilities(); + assert!( + caps_after_first.supports_encryption, + "first response must teach support_encryption" + ); + assert!( + !caps_after_first.supports_ephemeral_encryption, + "Persistent server must NOT advertise ephemeral on first response" + ); + + // Inject a second plaintext event signed by the server, carrying + // support_encryption_ephemeral (simulates a capability upgrade). + let client_pubkey = client_pool.mock_public_key(); + let second_response = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress" + }); + let inject_event = EventBuilder::new( + Kind::Custom(CTXVM_MESSAGES_KIND), + second_response.to_string(), + ) + .tags(vec![ + Tag::public_key(client_pubkey), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + ), + ]) + .sign_with_keys(&server_keys) + .unwrap(); + + client_pool.publish_event(&inject_event).await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + + let caps_after_second = client.discovered_server_capabilities(); + assert!( + caps_after_second.supports_encryption, + "support_encryption must survive OR-assign (not downgraded)" + ); + assert!( + caps_after_second.supports_ephemeral_encryption, + "support_encryption_ephemeral must be OR-assigned from second event" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_baseline_event_not_replaced_by_later_responses() { + // The first inbound event carrying discovery tags becomes the baseline. + // Later events with different tags must NOT replace it. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("BaselineServer".to_string())), + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + // First roundtrip — establishes baseline. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let baseline = client.get_server_initialize_event(); + assert!( + baseline.is_some(), + "baseline must be set after first response" + ); + let baseline_id = baseline.unwrap().id; + + // Inject a second event with different discovery tags. + let client_pubkey = client_pool.mock_public_key(); + let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress" + }); + let inject_event = + EventBuilder::new(Kind::Custom(CTXVM_MESSAGES_KIND), notification.to_string()) + .tags(vec![ + Tag::public_key(client_pubkey), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + ), + ]) + .sign_with_keys(&server_keys) + .unwrap(); + + client_pool.publish_event(&inject_event).await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + + let baseline_after = client.get_server_initialize_event(); + assert_eq!( + baseline_after.unwrap().id, + baseline_id, + "baseline event must NOT be replaced by later events" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_optional_encryption_emits_discovery_tags() { + // Client with Optional encryption must include discovery tags in the + // inner signed event. We decrypt the published gift wrap to verify. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = client_pool.stored_events().await; + let gift_wrap = events + .iter() + .find(|e| { + e.kind == Kind::Custom(GIFT_WRAP_KIND) + || e.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + }) + .expect("Optional encryption must produce a gift-wrapped event"); + + // Decrypt using the server's keys (the recipient). + let signer: Arc = Arc::new(server_keys); + let decrypted_json = + contextvm_sdk::encryption::decrypt_gift_wrap_single_layer(&signer, gift_wrap) + .await + .expect("gift wrap must be decryptable with server keys"); + + let inner: Event = + serde_json::from_str(&decrypted_json).expect("decrypted content must be a valid Event"); + + assert!( + has_tag_name(&inner, tags::SUPPORT_ENCRYPTION), + "inner event must carry support_encryption tag" + ); + assert!( + has_tag_name(&inner, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "inner event must carry support_encryption_ephemeral tag (Optional gift-wrap mode)" + ); +} +// ── Multi-client support ───────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multi_client_concurrent_requests_both_get_responses() { + // Two different clients send requests to the same server; both must get + // their own response (the single-peer barrier is removed). + let mut pools = MockRelayPool::create_linked_group(3); + let server_pool = pools.remove(0); + let client_b_pool = pools.remove(1); + let client_a_pool = pools.remove(0); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client_a = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_a_pool), + ) + .await + .expect("create client A"); + + let mut client_b = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_b_pool), + ) + .await + .expect("create client B"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_a_rx = client_a + .take_message_receiver() + .expect("client A message receiver"); + let mut client_b_rx = client_b + .take_message_receiver() + .expect("client B message receiver"); + + server.start().await.expect("server start"); + client_a.start().await.expect("client A start"); + client_b.start().await.expect("client B start"); + let_event_loops_start().await; + + // Client A sends a request. + let req_a = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }); + client_a.send(&req_a).await.expect("client A send"); + + // Client B sends a request. + let req_b = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + client_b.send(&req_b).await.expect("client B send"); + + // Server receives both requests (order may vary). + let incoming_1 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout rx 1") + .expect("rx closed 1"); + let incoming_2 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout rx 2") + .expect("rx closed 2"); + + // Send responses to both. + let resp_1 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: incoming_1.message.id().unwrap().clone(), + result: serde_json::json!({"tools": []}), + }); + server + .send_response(&incoming_1.event_id, resp_1) + .await + .expect("server respond to 1"); + + let resp_2 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: incoming_2.message.id().unwrap().clone(), + result: serde_json::json!({"tools": []}), + }); + server + .send_response(&incoming_2.event_id, resp_2) + .await + .expect("server respond to 2"); + + // Both clients must receive their respective response. + let resp_a = tokio::time::timeout(Duration::from_millis(500), client_a_rx.recv()) + .await + .expect("timeout client A response") + .expect("client A channel closed"); + let resp_b = tokio::time::timeout(Duration::from_millis(500), client_b_rx.recv()) + .await + .expect("timeout client B response") + .expect("client B channel closed"); + + assert!( + matches!(resp_a, JsonRpcMessage::Response(_)), + "client A must receive a response" + ); + assert!( + matches!(resp_b, JsonRpcMessage::Response(_)), + "client B must receive a response" + ); +} + +// ── Session store LRU tests ───────────────────────────────────────────────── + +use contextvm_sdk::transport::server::SessionStore; +use contextvm_sdk::ServerEventRouteStore; + +#[tokio::test] +async fn session_store_lru_eviction() { + let store = SessionStore::with_capacity(3); + let r = ServerEventRouteStore::new(); + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + // 4th session evicts the oldest ("a") + store.get_or_create_session("d", false, &r).await; + + assert!( + store.get_session("a").await.is_none(), + "oldest session must be evicted when capacity is exceeded" + ); + assert!(store.get_session("b").await.is_some()); + assert!(store.get_session("c").await.is_some()); + assert!(store.get_session("d").await.is_some()); + assert_eq!(store.session_count().await, 3); +} + +#[tokio::test] +async fn session_store_eviction_callback_fires() { + let evicted_keys: Arc>> = + Arc::new(std::sync::Mutex::new(Vec::new())); + let captured = evicted_keys.clone(); + let r = ServerEventRouteStore::new(); + + let mut store = SessionStore::with_capacity(2); + store.set_eviction_callback(std::sync::Arc::new(move |pubkey| { + captured.lock().unwrap().push(pubkey); + })); + + store.get_or_create_session("x", false, &r).await; + store.get_or_create_session("y", false, &r).await; + // Adding "z" evicts "x" + store.get_or_create_session("z", false, &r).await; + + let keys = evicted_keys.lock().unwrap(); + assert_eq!(keys.len(), 1, "callback must fire exactly once"); + assert_eq!(keys[0], "x", "evicted key must be the oldest session"); +} + +// ── Event loop cancellation on close() ────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_close_stops_event_loop() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut rx = client.take_message_receiver().expect("message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Close should cancel the event loop, causing the rx channel to close. + client.close().await.expect("client close"); + + // The receiver must resolve to None (closed) within a short timeout. + let result = tokio::time::timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + matches!(result, Ok(None)), + "after close(), message receiver must yield None (channel closed)" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_close_stops_event_loop() { + let (_client_pool, server_pool) = MockRelayPool::create_pair(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut rx = server.take_message_receiver().expect("message receiver"); + server.start().await.expect("server start"); + let_event_loops_start().await; + + // Close should cancel both event loop and cleanup tasks. + server.close().await.expect("server close"); + + // The receiver must resolve to None (closed) within a short timeout. + let result = tokio::time::timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + matches!(result, Ok(None)), + "after close(), message receiver must yield None (channel closed)" + ); +}