From 31e4bd14baace2722e20aa78be40fcc401f054af Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Fri, 3 Apr 2026 10:51:23 +0530 Subject: [PATCH 01/69] fix: resolve clippy warnings and align naming with TS SDK MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace manual Default impl for EncryptionMode with #[derive(Default)] - Remove unused Instant import from server.rs - Rename is_public_server → is_announced_server to match TS SDK (cd7f411) - Update hardcoded protocol version to 2025-07-02 - Update examples, README, and DESIGN.md accordingly --- DESIGN.md | 2 +- README.md | 4 ++-- examples/gateway.rs | 2 +- src/core/types.rs | 9 ++------- src/gateway/mod.rs | 6 +++--- src/transport/client.rs | 4 ++-- src/transport/server.rs | 10 +++++----- 7 files changed, 16 insertions(+), 21 deletions(-) diff --git a/DESIGN.md b/DESIGN.md index fc9467f..7fe0ab7 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -221,7 +221,7 @@ tracing-subscriber = "0.3" - Unit test: rejects events from wrong server pubkey - [x] **2.4** Implement `NostrServerTransport` - - Config: relay_urls, encryption_mode, server_info, is_public_server, allowed_public_keys, excluded_capabilities, cleanup_interval_ms, session_timeout_ms + - Config: relay_urls, encryption_mode, server_info, is_announced_server, allowed_public_keys, excluded_capabilities, cleanup_interval_ms, session_timeout_ms - Implements `Transport` trait - Features: - Subscribe to events targeting server pubkey diff --git a/README.md b/README.md index 79ed657..dfe1d66 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ async fn main() -> contextvm_sdk::Result<()> { about: Some("Tools via Nostr".into()), ..Default::default() }), - is_public_server: true, + is_announced_server: true, ..Default::default() }, }; @@ -201,7 +201,7 @@ metadata-private delivery. Server announcements (kinds 11316–11320) are always | `relay_urls` | `["wss://relay.damus.io"]` | Nostr relays to connect to | | `encryption_mode` | `Optional` | Encryption policy | | `server_info` | `None` | Server metadata for announcements | -| `is_public_server` | `false` | Whether to publish announcements | +| `is_announced_server` | `false` | Whether to publish announcements (CEP-6) | | `allowed_public_keys` | `[]` (allow all) | Client pubkey allowlist (hex) | | `excluded_capabilities` | `[]` | Methods exempt from allowlist | | `session_timeout` | `300s` | Inactive session expiry | diff --git a/examples/gateway.rs b/examples/gateway.rs index ee6b1c5..2effaba 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -25,7 +25,7 @@ async fn main() -> contextvm_sdk::Result<()> { about: Some("A simple echo tool exposed via ContextVM".to_string()), ..Default::default() }), - is_public_server: true, + is_announced_server: true, ..Default::default() }, }; diff --git a/src/core/types.rs b/src/core/types.rs index cb10773..4ea2e34 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -10,10 +10,11 @@ use std::time::Instant; /// /// Controls whether MCP messages are sent as plaintext kind 25910 events /// or wrapped in NIP-59 gift wraps (kind 1059) for end-to-end encryption. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum EncryptionMode { /// Encrypt responses only when the incoming request was encrypted (mirror mode). + #[default] Optional, /// Enforce encryption for all messages; reject plaintext. Required, @@ -21,12 +22,6 @@ pub enum EncryptionMode { Disabled, } -impl Default for EncryptionMode { - fn default() -> Self { - Self::Optional - } -} - // ── Server info ───────────────────────────────────────────────────── /// Server information for announcements (kind 11316). diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index d907611..3752e79 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -98,7 +98,7 @@ mod tests { version: Some("1.0.0".to_string()), ..Default::default() }), - is_public_server: true, + is_announced_server: true, allowed_public_keys: vec!["abc123".to_string()], excluded_capabilities: vec![], cleanup_interval: Duration::from_secs(120), @@ -109,7 +109,7 @@ mod tests { assert_eq!(config.nostr_config.relay_urls, vec!["wss://relay.example.com"]); assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Required); - assert!(config.nostr_config.is_public_server); + assert!(config.nostr_config.is_announced_server); assert_eq!(config.nostr_config.allowed_public_keys.len(), 1); assert!(config.nostr_config.server_info.as_ref().unwrap().name.as_ref().unwrap() == "Test Gateway"); } @@ -120,6 +120,6 @@ mod tests { nostr_config: NostrServerTransportConfig::default(), }; assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); - assert!(!config.nostr_config.is_public_server); + assert!(!config.nostr_config.is_announced_server); } } diff --git a/src/transport/client.rs b/src/transport/client.rs index 7dffcf5..f29a103 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -153,7 +153,7 @@ impl NostrClientTransport { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": "2025-03-26", + "protocolVersion": "2025-07-02", "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" @@ -275,7 +275,7 @@ mod tests { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": "2025-03-26", + "protocolVersion": "2025-07-02", "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" diff --git a/src/transport/server.rs b/src/transport/server.rs index 43d97a8..beb2429 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Duration; use nostr_sdk::prelude::*; use tokio::sync::RwLock; @@ -27,8 +27,8 @@ pub struct NostrServerTransportConfig { pub encryption_mode: EncryptionMode, /// Server information for announcements. pub server_info: Option, - /// Whether this is a public server (publishes announcements). - pub is_public_server: bool, + /// Whether this server publishes public announcements (CEP-6). + pub is_announced_server: bool, /// Allowed client public keys (hex). Empty = allow all. pub allowed_public_keys: Vec, /// Capabilities excluded from pubkey whitelisting. @@ -45,7 +45,7 @@ impl Default for NostrServerTransportConfig { relay_urls: vec!["wss://relay.damus.io".to_string()], encryption_mode: EncryptionMode::Optional, server_info: None, - is_public_server: false, + is_announced_server: false, allowed_public_keys: Vec::new(), excluded_capabilities: Vec::new(), cleanup_interval: Duration::from_secs(60), @@ -745,7 +745,7 @@ mod tests { fn test_config_defaults() { let config = NostrServerTransportConfig::default(); assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); - assert!(!config.is_public_server); + assert!(!config.is_announced_server); assert!(config.allowed_public_keys.is_empty()); assert!(config.excluded_capabilities.is_empty()); assert_eq!(config.cleanup_interval, Duration::from_secs(60)); From f2f7eb9d8896e35bf67b8a34669328c3effd1d7c Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Fri, 3 Apr 2026 10:53:19 +0530 Subject: [PATCH 02/69] feat: add CEP-17/CEP-19 constants and bootstrap relay URLs - Add EPHEMERAL_GIFT_WRAP_KIND (21059) for CEP-19 - Add RELAY_LIST_METADATA_KIND (10002) for CEP-17 - Add tags: RELAY, SUPPORT_ENCRYPTION_EPHEMERAL - Add DEFAULT_BOOTSTRAP_RELAY_URLS matching TS SDK - Add DEFAULT_LRU_SIZE, DEFAULT_TIMEOUT_MS - Add INITIALIZE_METHOD, NOTIFICATIONS_INITIALIZED_METHOD - Add 11 unit tests verifying values against spec and NIP-01 kind ranges --- src/core/constants.rs | 150 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/src/core/constants.rs b/src/core/constants.rs index 3748a3b..345c91b 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -9,6 +9,15 @@ pub const CTXVM_MESSAGES_KIND: u16 = 25910; /// Encrypted messages using NIP-59 Gift Wrap (kind 1059) pub const GIFT_WRAP_KIND: u16 = 1059; +/// Ephemeral variant of NIP-59 Gift Wrap (kind 21059, CEP-19) +/// +/// Same structure and semantics as kind 1059, but in NIP-01's ephemeral range. +/// Relays are not expected to store ephemeral events beyond transient forwarding. +pub const EPHEMERAL_GIFT_WRAP_KIND: u16 = 21059; + +/// Replaceable relay list metadata event following NIP-65 (CEP-17) +pub const RELAY_LIST_METADATA_KIND: u16 = 10002; + /// Server announcement (addressable, kind 11316) pub const SERVER_ANNOUNCEMENT_KIND: u16 = 11316; @@ -29,6 +38,9 @@ pub mod tags { /// Public key tag pub const PUBKEY: &str = "p"; + /// Relay URL tag (CEP-17) + pub const RELAY: &str = "r"; + /// Event ID tag for correlation pub const EVENT_ID: &str = "e"; @@ -49,11 +61,39 @@ pub mod tags { /// Support encryption tag pub const SUPPORT_ENCRYPTION: &str = "support_encryption"; + + /// Support ephemeral gift wrap kind (21059) for encrypted messages (CEP-19) + pub const SUPPORT_ENCRYPTION_EPHEMERAL: &str = "support_encryption_ephemeral"; } /// Maximum message size (1MB) pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Default LRU cache size for deduplication +pub const DEFAULT_LRU_SIZE: usize = 5000; + +/// Default timeout for network/relay operations (30 seconds) +pub const DEFAULT_TIMEOUT_MS: u64 = 30_000; + +/// Default relay targets for discoverability publication (CEP-17). +/// +/// These are used as additional publication targets for server metadata, +/// even when they are not part of the server's operational relay list. +pub const DEFAULT_BOOTSTRAP_RELAY_URLS: &[&str] = &[ + "wss://relay.damus.io", + "wss://relay.primal.net", + "wss://nos.lol", + "wss://relay.snort.social/", + "wss://nostr.mom/", + "wss://nostr.oxtr.dev/", +]; + +/// MCP protocol method for the initialization request +pub const INITIALIZE_METHOD: &str = "initialize"; + +/// MCP protocol method for the initialized notification +pub const NOTIFICATIONS_INITIALIZED_METHOD: &str = "notifications/initialized"; + /// Kinds that should never be encrypted (public announcements) pub const UNENCRYPTED_KINDS: &[u16] = &[ SERVER_ANNOUNCEMENT_KIND, @@ -62,3 +102,113 @@ pub const UNENCRYPTED_KINDS: &[u16] = &[ RESOURCETEMPLATES_LIST_KIND, PROMPTS_LIST_KIND, ]; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_event_kind_values_match_spec() { + assert_eq!(CTXVM_MESSAGES_KIND, 25910); + assert_eq!(GIFT_WRAP_KIND, 1059); + assert_eq!(EPHEMERAL_GIFT_WRAP_KIND, 21059); + assert_eq!(RELAY_LIST_METADATA_KIND, 10002); + assert_eq!(SERVER_ANNOUNCEMENT_KIND, 11316); + assert_eq!(TOOLS_LIST_KIND, 11317); + assert_eq!(RESOURCES_LIST_KIND, 11318); + assert_eq!(RESOURCETEMPLATES_LIST_KIND, 11319); + assert_eq!(PROMPTS_LIST_KIND, 11320); + } + + #[test] + fn test_tag_values_match_ts_sdk() { + assert_eq!(tags::PUBKEY, "p"); + assert_eq!(tags::RELAY, "r"); + assert_eq!(tags::EVENT_ID, "e"); + assert_eq!(tags::CAPABILITY, "cap"); + assert_eq!(tags::NAME, "name"); + assert_eq!(tags::WEBSITE, "website"); + assert_eq!(tags::PICTURE, "picture"); + assert_eq!(tags::ABOUT, "about"); + assert_eq!(tags::SUPPORT_ENCRYPTION, "support_encryption"); + assert_eq!( + tags::SUPPORT_ENCRYPTION_EPHEMERAL, + "support_encryption_ephemeral" + ); + } + + #[test] + fn test_ephemeral_gift_wrap_in_ephemeral_range() { + // NIP-01: ephemeral events are 20000 <= kind < 30000 + assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); + assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); + } + + #[test] + fn test_ctxvm_messages_in_ephemeral_range() { + // NIP-01: ephemeral events are 20000 <= kind < 30000 + assert!(CTXVM_MESSAGES_KIND >= 20000); + assert!(CTXVM_MESSAGES_KIND < 30000); + } + + #[test] + fn test_relay_list_metadata_in_replaceable_range() { + // NIP-01: replaceable events are 10000 <= kind < 20000 + assert!(RELAY_LIST_METADATA_KIND >= 10000); + assert!(RELAY_LIST_METADATA_KIND < 20000); + } + + #[test] + fn test_announcement_kinds_in_addressable_range() { + // NIP-01: addressable events are 30000 <= kind < 40000 + // However, the spec uses 11316-11320 which are in the replaceable range. + // These are parameterized replaceable events per the ContextVM spec. + for &kind in UNENCRYPTED_KINDS { + assert!(kind >= 11316); + assert!(kind <= 11320); + } + } + + #[test] + fn test_bootstrap_relays_are_wss() { + for url in DEFAULT_BOOTSTRAP_RELAY_URLS { + assert!( + url.starts_with("wss://"), + "Bootstrap relay must use wss: {url}" + ); + } + } + + #[test] + fn test_bootstrap_relays_nonempty() { + assert!( + !DEFAULT_BOOTSTRAP_RELAY_URLS.is_empty(), + "Must have at least one bootstrap relay" + ); + } + + #[test] + fn test_mcp_method_constants() { + assert_eq!(INITIALIZE_METHOD, "initialize"); + assert_eq!( + NOTIFICATIONS_INITIALIZED_METHOD, + "notifications/initialized" + ); + } + + #[test] + fn test_unencrypted_kinds_contains_all_announcements() { + assert!(UNENCRYPTED_KINDS.contains(&SERVER_ANNOUNCEMENT_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&TOOLS_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&RESOURCES_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&RESOURCETEMPLATES_LIST_KIND)); + assert!(UNENCRYPTED_KINDS.contains(&PROMPTS_LIST_KIND)); + } + + #[test] + fn test_gift_wrap_not_in_unencrypted() { + assert!(!UNENCRYPTED_KINDS.contains(&GIFT_WRAP_KIND)); + assert!(!UNENCRYPTED_KINDS.contains(&EPHEMERAL_GIFT_WRAP_KIND)); + } +} + From 924bc3d35e094bb60c399a531f48c9c6b5e97cbb Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Fri, 3 Apr 2026 10:59:23 +0530 Subject: [PATCH 03/69] feat: support ephemeral gift wraps (kind 21059) per CEP-19 - Update base transport to subscribe to both kind 1059 and 21059 - Handle kind 21059 in both client and server transport event loops - Advertise support_encryption_ephemeral tag in server announcements - Clean up unused imports in base.rs and client.rs tests --- src/transport/base.rs | 16 ++++++++++++++-- src/transport/client.rs | 5 +++-- src/transport/server.rs | 8 +++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/transport/base.rs b/src/transport/base.rs index 0c19a4a..77c048b 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -71,10 +71,23 @@ impl BaseTransport { let two_days_ago = Timestamp::from(Timestamp::now().as_u64().saturating_sub(2 * 24 * 3600)); let gift_wrap_filter = Filter::new() .kind(Kind::Custom(GIFT_WRAP_KIND)) + .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) + .since(two_days_ago); + + // CEP-19: Ephemeral gift wraps (kind 21059) — same semantics as 1059 + // but in the ephemeral range so relays don't persist them. + let ephemeral_gift_wrap_filter = Filter::new() + .kind(Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag) .since(two_days_ago); - self.relay_pool.subscribe(vec![ephemeral_filter, gift_wrap_filter]).await + self.relay_pool + .subscribe(vec![ + ephemeral_filter, + gift_wrap_filter, + ephemeral_gift_wrap_filter, + ]) + .await } /// Convert a Nostr event to an MCP message with validation. @@ -163,7 +176,6 @@ impl BaseTransport { mod tests { use super::*; use crate::core::types::*; - use nostr_sdk::prelude::*; // Test should_encrypt logic without constructing full BaseTransport fn should_encrypt(mode: EncryptionMode, kind: u16, is_encrypted: Option) -> bool { diff --git a/src/transport/client.rs b/src/transport/client.rs index f29a103..fa49126 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -181,7 +181,9 @@ impl NostrClientTransport { if let RelayPoolNotification::Event { event, .. } = notification { // Handle gift-wrapped events let (actual_event_content, actual_pubkey, e_tag) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) { + if event.kind == Kind::Custom(GIFT_WRAP_KIND) + || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + { // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, @@ -246,7 +248,6 @@ impl NostrClientTransport { #[cfg(test)] mod tests { use super::*; - use crate::core::types::*; #[test] fn test_config_defaults() { diff --git a/src/transport/server.rs b/src/transport/server.rs index beb2429..c6483a6 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -323,6 +323,10 @@ impl NostrServerTransport { TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), Vec::::new(), )); + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); } let builder = @@ -429,7 +433,9 @@ impl NostrServerTransport { while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { let (content, sender_pubkey, event_id, is_encrypted) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) { + if event.kind == Kind::Custom(GIFT_WRAP_KIND) + || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + { if encryption_mode == EncryptionMode::Disabled { tracing::warn!("Received encrypted message but encryption is disabled"); continue; From 79c62faef708d685d8bdae291ff5bf987329146b Mon Sep 17 00:00:00 2001 From: piyush-1337 Date: Fri, 3 Apr 2026 12:44:33 +0530 Subject: [PATCH 04/69] fix: return signed event id for encrypted message correlation --- src/transport/base.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transport/base.rs b/src/transport/base.rs index 0c19a4a..de17b83 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -101,7 +101,8 @@ impl BaseTransport { /// Send an MCP message to a recipient, optionally encrypting. /// - /// Returns the event ID of the published event. + /// Returns the signed MCP event ID. + /// When encrypted, this is the inner signed event ID. pub async fn send_mcp_message( &self, message: &JsonRpcMessage, @@ -113,6 +114,7 @@ impl BaseTransport { let should_encrypt = self.should_encrypt(kind, is_encrypted); let event = self.create_signed_event(message, kind, tags).await?; + let signed_event_id = event.id; if should_encrypt { // Single-layer gift wrap: JSON.stringify(signedEvent) → NIP-44 encrypt @@ -124,14 +126,18 @@ impl BaseTransport { let gift_wrap_event = encryption::gift_wrap_single_layer( &signer, recipient, &event_json, ).await?; - let event_id = self.relay_pool.publish_event(&gift_wrap_event).await?; - tracing::debug!(event_id = %event_id, "Sent encrypted MCP message"); - Ok(event_id) + self.relay_pool.publish_event(&gift_wrap_event).await?; + tracing::debug!( + signed_event_id = %signed_event_id, + envelope_id = %gift_wrap_event.id, + "Sent encrypted MCP message" + ); } else { - let event_id = self.relay_pool.publish_event(&event).await?; - tracing::debug!(event_id = %event_id, "Sent unencrypted MCP message"); - Ok(event_id) + self.relay_pool.publish_event(&event).await?; + tracing::debug!(signed_event_id = %signed_event_id, "Sent unencrypted MCP message"); } + + Ok(signed_event_id) } /// Determine whether a message should be encrypted. From d7d1ee5c4ec7b07d9484beca5d0238207f68d293 Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Fri, 3 Apr 2026 15:20:17 +0530 Subject: [PATCH 05/69] chore: align with CEP-4 and official rmcp SDK - Update gift wrap subscription to use since:now() per CEP-4 - Add rmcp v1.3.0 dependency - Use ProtocolVersion::LATEST in emulated client handshake and tests --- Cargo.toml | 1 + src/transport/base.rs | 11 +++-------- src/transport/client.rs | 5 +++-- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 165dae7..b2e3a60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" +rmcp = "1.3.0" [dev-dependencies] tokio-test = "0.4" diff --git a/src/transport/base.rs b/src/transport/base.rs index 77c048b..d6e0b00 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -60,26 +60,21 @@ impl BaseTransport { pub async fn subscribe_for_pubkey(&self, pubkey: &PublicKey) -> Result<()> { let p_tag = pubkey.to_hex(); - // Ephemeral ContextVM messages — safe to use since:now() let ephemeral_filter = Filter::new() .kind(Kind::Custom(CTXVM_MESSAGES_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) .since(Timestamp::now()); - // NIP-59 gift wraps — timestamps are randomized (up to ±48h or more), - // so we must NOT use since:now(). Limit to recent window instead. - let two_days_ago = Timestamp::from(Timestamp::now().as_u64().saturating_sub(2 * 24 * 3600)); + let now = Timestamp::now(); let gift_wrap_filter = Filter::new() .kind(Kind::Custom(GIFT_WRAP_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) - .since(two_days_ago); + .since(now); - // CEP-19: Ephemeral gift wraps (kind 21059) — same semantics as 1059 - // but in the ephemeral range so relays don't persist them. let ephemeral_gift_wrap_filter = Filter::new() .kind(Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag) - .since(two_days_ago); + .since(now); self.relay_pool .subscribe(vec![ diff --git a/src/transport/client.rs b/src/transport/client.rs index fa49126..5a0255b 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -17,6 +17,7 @@ use crate::core::types::*; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; +use rmcp::model::ProtocolVersion; /// Configuration for the client transport. pub struct NostrClientTransportConfig { @@ -153,7 +154,7 @@ impl NostrClientTransport { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": "2025-07-02", + "protocolVersion": ProtocolVersion::LATEST.to_string(), "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" @@ -276,7 +277,7 @@ mod tests { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": "2025-07-02", + "protocolVersion": ProtocolVersion::LATEST.to_string(), "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" From e569fa994031128b5a3c9ba2c4b398cf9e7c3308 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Fri, 3 Apr 2026 16:33:22 +0530 Subject: [PATCH 06/69] feat: implemented basic rmcp support helpers and functions --- Cargo.toml | 8 ++ src/lib.rs | 6 + src/rmcp_transport/convert.rs | 44 ++++++ src/rmcp_transport/mod.rs | 13 ++ src/rmcp_transport/worker.rs | 246 ++++++++++++++++++++++++++++++++++ 5 files changed, 317 insertions(+) create mode 100644 src/rmcp_transport/convert.rs create mode 100644 src/rmcp_transport/mod.rs create mode 100644 src/rmcp_transport/worker.rs diff --git a/Cargo.toml b/Cargo.toml index 165dae7..8ea0a1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,14 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" +# Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) +rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } + +[features] +# Enable rmcp by default while keeping legacy APIs available. +default = ["rmcp"] +rmcp = ["dep:rmcp"] + [dev-dependencies] tokio-test = "0.4" tracing-subscriber = "0.3" diff --git a/src/lib.rs b/src/lib.rs index becd92e..cc2a09e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,6 +45,9 @@ pub mod relay; pub mod signer; pub mod transport; +#[cfg(feature = "rmcp")] +pub mod rmcp_transport; + // Re-export commonly used types pub use core::error::{Error, Result}; pub use core::types::{ @@ -55,3 +58,6 @@ pub use discovery::ServerAnnouncement; pub use relay::RelayPool; pub use transport::client::{NostrClientTransport, NostrClientTransportConfig}; pub use transport::server::{IncomingRequest, NostrServerTransport, NostrServerTransportConfig}; + +#[cfg(feature = "rmcp")] +pub use rmcp; diff --git a/src/rmcp_transport/convert.rs b/src/rmcp_transport/convert.rs new file mode 100644 index 0000000..e6c49bb --- /dev/null +++ b/src/rmcp_transport/convert.rs @@ -0,0 +1,44 @@ +//! Conversion boundary between internal JSON-RPC messages and rmcp message types. +//! +//! These helpers intentionally convert via serde JSON to preserve wire-level +//! compatibility and avoid fragile hand-mapping between evolving type systems. + +use crate::core::types::JsonRpcMessage; + +/// Convert internal JSON-RPC message into rmcp server RX message. +/// +/// Role mapping: +/// - RoleServer RX receives client-originated messages. +pub fn internal_to_rmcp_server_rx( + msg: &JsonRpcMessage, +) -> Option> { + let value = serde_json::to_value(msg).ok()?; + serde_json::from_value(value).ok() +} + +/// Convert internal JSON-RPC message into rmcp client RX message. +/// +/// Role mapping: +/// - RoleClient RX receives server-originated messages. +pub fn internal_to_rmcp_client_rx( + msg: &JsonRpcMessage, +) -> Option> { + let value = serde_json::to_value(msg).ok()?; + serde_json::from_value(value).ok() +} + +/// Convert rmcp server TX message back into internal JSON-RPC. +pub fn rmcp_server_tx_to_internal( + msg: rmcp::service::TxJsonRpcMessage, +) -> Option { + let value = serde_json::to_value(msg).ok()?; + serde_json::from_value(value).ok() +} + +/// Convert rmcp client TX message back into internal JSON-RPC. +pub fn rmcp_client_tx_to_internal( + msg: rmcp::service::TxJsonRpcMessage, +) -> Option { + let value = serde_json::to_value(msg).ok()?; + serde_json::from_value(value).ok() +} diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs new file mode 100644 index 0000000..c1857f8 --- /dev/null +++ b/src/rmcp_transport/mod.rs @@ -0,0 +1,13 @@ +//! RMCP integration scaffolding. +//! +//! This module will host adapters that bridge the existing Nostr transport +//! implementation with rmcp services. + +pub mod convert; +pub mod worker; + +pub use convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, +}; +pub use worker::{NostrClientWorker, NostrServerWorker}; diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs new file mode 100644 index 0000000..6807fbc --- /dev/null +++ b/src/rmcp_transport/worker.rs @@ -0,0 +1,246 @@ +//! rmcp worker adapters. +//! +//! This file defines wrapper types that bind existing ContextVM Nostr +//! transports to rmcp's worker abstraction. + +use crate::core::error::Result; +use crate::core::types::JsonRpcMessage; +use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; + +use super::convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, +}; + +/// rmcp server worker wrapper for ContextVM Nostr server transport. +pub struct NostrServerWorker { + transport: NostrServerTransport, +} + +impl NostrServerWorker { + /// Create a new server worker from existing server transport config. + pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrServerTransport::new(signer, config).await?; + Ok(Self { transport }) + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrServerTransport { + &self.transport + } +} + +impl Worker for NostrServerWorker { + type Error = crate::core::error::Error; + type Role = rmcp::RoleServer; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting server transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("server message receiver already taken".to_string()), + "taking server message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(mut incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + // Use event_id as internal request id to keep response routing deterministic. + if let JsonRpcMessage::Request(ref mut req) = incoming.message { + req.id = serde_json::json!(incoming.event_id.clone()); + } + + if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&incoming.message) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!("Failed to convert incoming server-side message to rmcp format"); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_server_tx_to_internal(outbound.message) { + self.forward_server_internal(internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp server message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!("Failed to close server transport cleanly: {e}"); + } + + Err(quit_reason) + } +} + +/// rmcp client worker wrapper for ContextVM Nostr client transport. +pub struct NostrClientWorker { + transport: NostrClientTransport, +} + +impl NostrClientWorker { + /// Create a new client worker from existing client transport config. + pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrClientTransport::new(signer, config).await?; + Ok(Self { transport }) + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrClientTransport { + &self.transport + } +} + +impl Worker for NostrClientWorker { + type Error = crate::core::error::Error; + type Role = rmcp::RoleClient; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting client transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("client message receiver already taken".to_string()), + "taking client message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + if let Some(rmcp_msg) = internal_to_rmcp_client_rx(&incoming) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!("Failed to convert incoming client-side message to rmcp format"); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_client_tx_to_internal(outbound.message) { + self.transport.send(&internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp client message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!("Failed to close client transport cleanly: {e}"); + } + + Err(quit_reason) + } +} + +impl NostrServerWorker { + async fn forward_server_internal(&self, message: JsonRpcMessage) -> Result<()> { + match message { + JsonRpcMessage::Response(resp) => { + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server response id must be a string event id".to_string(), + ) + })?; + self.transport + .send_response(&event_id, JsonRpcMessage::Response(resp)) + .await + } + JsonRpcMessage::ErrorResponse(resp) => { + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server error response id must be a string event id".to_string(), + ) + })?; + self.transport + .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) + .await + } + JsonRpcMessage::Notification(notification) => { + let message = JsonRpcMessage::Notification(notification); + self.transport.broadcast_notification(&message).await + } + JsonRpcMessage::Request(_) => Err(crate::core::error::Error::Validation( + "rmcp server worker cannot forward outbound request messages".to_string(), + )), + } + } +} From 78edc9cb0c68695ad01c4b10dc1622459dd31bf2 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Fri, 3 Apr 2026 20:46:38 +0530 Subject: [PATCH 07/69] feat: wired logic in gateway --- src/core/constants.rs | 6 + src/gateway/mod.rs | 26 +++ src/rmcp_transport/convert.rs | 101 +++++++++ src/rmcp_transport/mod.rs | 6 +- src/rmcp_transport/pipeline_tests.rs | 309 +++++++++++++++++++++++++++ src/rmcp_transport/worker.rs | 113 ++++++++-- src/transport/client.rs | 16 +- src/transport/server.rs | 49 +++++ 8 files changed, 598 insertions(+), 28 deletions(-) create mode 100644 src/rmcp_transport/pipeline_tests.rs diff --git a/src/core/constants.rs b/src/core/constants.rs index 345c91b..8aecf78 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -69,6 +69,12 @@ pub mod tags { /// Maximum message size (1MB) pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// MCP protocol version string used in initialize responses. +/// +/// Matches the `protocolVersion` field of the `InitializeResult` JSON-RPC response. +/// Keep this in sync with the MCP spec and rmcp's `ProtocolVersion::LATEST`. +pub const MCP_PROTOCOL_VERSION: &str = "2025-11-25"; + /// Default LRU cache size for deduplication pub const DEFAULT_LRU_SIZE: usize = 5000; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 3752e79..345f2b3 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -81,6 +81,32 @@ impl NostrMCPGateway { } } +#[cfg(feature = "rmcp")] +impl NostrMCPGateway { + /// Start a gateway directly from an rmcp server handler. + /// + /// This additive API keeps the existing `new/start/send_response` flow intact, + /// while allowing rmcp-first usage through the worker adapter. + pub async fn serve_handler( + signer: T, + config: GatewayConfig, + handler: H, + ) -> Result> + where + T: nostr_sdk::prelude::IntoNostrSigner, + H: rmcp::ServerHandler, + { + use crate::rmcp_transport::NostrServerWorker; + use rmcp::ServiceExt; + + let worker = NostrServerWorker::new(signer, config.nostr_config).await?; + handler + .serve(worker) + .await + .map_err(|e| Error::Other(format!("rmcp server initialization failed: {e}"))) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/rmcp_transport/convert.rs b/src/rmcp_transport/convert.rs index e6c49bb..e575c37 100644 --- a/src/rmcp_transport/convert.rs +++ b/src/rmcp_transport/convert.rs @@ -42,3 +42,104 @@ pub fn rmcp_client_tx_to_internal( let value = serde_json::to_value(msg).ok()?; serde_json::from_value(value).ok() } + +#[cfg(all(test, feature = "rmcp"))] +mod tests { + use super::*; + use crate::core::types::{JsonRpcRequest, JsonRpcResponse}; + + #[test] + fn test_internal_request_to_rmcp_server_rx_ping() { + let internal = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "ping".to_string(), + params: None, + }); + + let rmcp_msg = internal_to_rmcp_server_rx(&internal) + .expect("expected conversion to rmcp server rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + + assert_eq!(value.get("method"), Some(&serde_json::json!("ping"))); + assert_eq!(value.get("id"), Some(&serde_json::json!(1))); + } + + #[test] + fn test_internal_response_to_rmcp_client_rx_empty_result() { + let internal = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + result: serde_json::json!({}), + }); + + let rmcp_msg = internal_to_rmcp_client_rx(&internal) + .expect("expected conversion to rmcp client rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + + assert_eq!(value.get("id"), Some(&serde_json::json!(42))); + assert_eq!(value.get("result"), Some(&serde_json::json!({}))); + } + + #[test] + fn test_rmcp_server_tx_to_internal_response() { + let rmcp_msg = rmcp::model::ServerJsonRpcMessage::response( + rmcp::model::ServerResult::empty(()), + rmcp::model::RequestId::Number(7), + ); + + let internal = rmcp_server_tx_to_internal(rmcp_msg) + .expect("expected conversion from rmcp server tx to internal JSON-RPC"); + + match internal { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!(7)); + assert_eq!(resp.result, serde_json::json!({})); + } + other => panic!("expected internal response, got {other:?}"), + } + } + + #[test] + fn test_rmcp_client_tx_to_internal_response() { + let rmcp_msg = rmcp::model::ClientJsonRpcMessage::response( + rmcp::model::ClientResult::empty(()), + rmcp::model::RequestId::Number(9), + ); + + let internal = rmcp_client_tx_to_internal(rmcp_msg) + .expect("expected conversion from rmcp client tx to internal JSON-RPC"); + + match internal { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!(9)); + assert_eq!(resp.result, serde_json::json!({})); + } + other => panic!("expected internal response, got {other:?}"), + } + } + + #[test] + fn test_server_rx_roundtrip_preserves_wire_shape() { + let internal = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("abc"), + method: "ping".to_string(), + params: None, + }); + + let rmcp_msg = internal_to_rmcp_server_rx(&internal) + .expect("expected conversion to rmcp server rx message"); + let value = serde_json::to_value(rmcp_msg).expect("serialize rmcp message to JSON"); + let roundtrip_internal: JsonRpcMessage = + serde_json::from_value(value).expect("deserialize back to internal JSON-RPC"); + + match roundtrip_internal { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!("abc")); + assert_eq!(req.method, "ping"); + } + other => panic!("expected internal request, got {other:?}"), + } + } +} diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs index c1857f8..44a2ab1 100644 --- a/src/rmcp_transport/mod.rs +++ b/src/rmcp_transport/mod.rs @@ -1,11 +1,13 @@ //! RMCP integration scaffolding. //! -//! This module will host adapters that bridge the existing Nostr transport -//! implementation with rmcp services. +//! This module bridges the existing Nostr transport implementation with rmcp services. pub mod convert; pub mod worker; +#[cfg(test)] +mod pipeline_tests; + pub use convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, diff --git a/src/rmcp_transport/pipeline_tests.rs b/src/rmcp_transport/pipeline_tests.rs new file mode 100644 index 0000000..d8da58d --- /dev/null +++ b/src/rmcp_transport/pipeline_tests.rs @@ -0,0 +1,309 @@ +//! End-to-end pipeline tests for the rmcp ↔ Nostr transport integration. +//! +//! These tests verify every step of the message journey without requiring a live +//! relay connection: +//! +//! ```text +//! Nostr event content (JSON string) +//! → serializers::nostr_event_to_mcp_message [Layer 1: deserialise] +//! → internal_to_rmcp_server_rx [Layer 2: type bridge] +//! → (rmcp handler processes it) [Layer 3: rmcp dispatch – simulated] +//! → rmcp_server_tx_to_internal [Layer 4: type bridge back] +//! → send_response (event_id correlation) [Layer 5: route back to Nostr – mocked] +//! ``` + +#[cfg(all(test, feature = "rmcp"))] +mod tests { + use std::collections::HashMap; + + use rmcp::model::{ClientJsonRpcMessage, ClientResult, RequestId, ServerJsonRpcMessage, ServerResult}; + + use crate::core::serializers; + use crate::core::types::{JsonRpcError, JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse}; + use crate::rmcp_transport::convert::{ + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, + rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, + }; + + // ── Layer 1: Nostr event content → JsonRpcMessage ────────────────────── + + #[test] + fn layer1_nostr_content_to_internal_request() { + let content = r#"{"jsonrpc":"2.0","id":1,"method":"ping","params":{}}"#; + let msg = serializers::nostr_event_to_mcp_message(content) + .expect("valid MCP request should parse"); + + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("ping")); + assert_eq!(msg.id(), Some(&serde_json::json!(1))); + } + + #[test] + fn layer1_nostr_content_to_internal_tools_list() { + let content = r#"{"jsonrpc":"2.0","id":"abc","method":"tools/list","params":{}}"#; + let msg = serializers::nostr_event_to_mcp_message(content).unwrap(); + assert_eq!(msg.method(), Some("tools/list")); + assert_eq!(msg.id(), Some(&serde_json::json!("abc"))); + } + + #[test] + fn layer1_nostr_content_to_internal_notification() { + let content = r#"{"jsonrpc":"2.0","method":"notifications/initialized"}"#; + let msg = serializers::nostr_event_to_mcp_message(content).unwrap(); + assert!(!msg.is_request()); + assert_eq!(msg.method(), Some("notifications/initialized")); + } + + #[test] + fn layer1_nostr_content_invalid_json_returns_none() { + assert!(serializers::nostr_event_to_mcp_message("not json").is_none()); + } + + #[test] + fn layer1_nostr_event_to_mcp_message_no_version_check() { + // DESIGN NOTE: nostr_event_to_mcp_message uses raw serde deserialization — + // it does NOT reject invalid jsonrpc versions. Version enforcement happens + // one layer up in base.rs via validate_message(), which IS tested separately + // in core::validation::tests::test_invalid_version and + // transport::base::tests::test_convert_event_to_mcp_invalid_jsonrpc_version. + // + // A message with jsonrpc "1.0" will parse successfully at the serializer + // layer because JsonRpcRequest accepts any String for the jsonrpc field. + let content = r#"{"jsonrpc":"1.0","id":1,"method":"ping"}"#; + // It parses — the struct captures jsonrpc as a plain String. + let msg = serializers::nostr_event_to_mcp_message(content); + // We don't assert None here; rejection happens in base.rs, not here. + // What we DO assert: if it parsed, the method and id are intact. + if let Some(msg) = msg { + assert_eq!(msg.method(), Some("ping")); + } + // The real rejection path is covered by: + // transport::base::tests::test_convert_event_to_mcp_invalid_jsonrpc_version + } + + // ── Layer 2: JsonRpcMessage → rmcp RxJsonRpcMessage (server) ─────────── + + #[test] + fn layer2_internal_request_converts_to_rmcp_server_rx() { + let msg = make_request("ping", serde_json::json!(1), None); + let rmcp = internal_to_rmcp_server_rx(&msg).expect("ping should convert"); + + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "ping"); + assert_eq!(v["id"], serde_json::json!(1)); + assert_eq!(v["jsonrpc"], "2.0"); + } + + #[test] + fn layer2_string_id_preserved_through_bridge() { + let msg = make_request("tools/list", serde_json::json!("req-xyz"), None); + let rmcp = internal_to_rmcp_server_rx(&msg).unwrap(); + + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["id"], serde_json::json!("req-xyz")); + } + + #[test] + fn layer2_notification_converts_to_rmcp_server_rx() { + let msg = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + let rmcp = internal_to_rmcp_server_rx(&msg) + .expect("initialized notification should convert"); + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "notifications/initialized"); + } + + #[test] + fn layer2_tools_list_with_params_converts() { + let msg = make_request( + "tools/list", + serde_json::json!(7), + Some(serde_json::json!({"cursor": "next-page"})), + ); + let rmcp = internal_to_rmcp_server_rx(&msg).unwrap(); + let v = serde_json::to_value(&rmcp).unwrap(); + assert_eq!(v["method"], "tools/list"); + assert_eq!(v["params"]["cursor"], "next-page"); + } + + // ── Layer 3+4: Simulated handler → rmcp response → internal ──────────── + + #[test] + fn layer4_rmcp_ping_response_roundtrip_number_id() { + // Simulate rmcp handler producing a ping response + let rmcp_response = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::Number(42), + ); + let internal = rmcp_server_tx_to_internal(rmcp_response) + .expect("ping response should convert back"); + + match internal { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id, serde_json::json!(42)); + assert_eq!(r.jsonrpc, "2.0"); + } + other => panic!("expected Response, got {other:?}"), + } + } + + #[test] + fn layer4_rmcp_ping_response_roundtrip_string_id() { + let rmcp_response = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from("req-xyz")), + ); + let internal = rmcp_server_tx_to_internal(rmcp_response).unwrap(); + + match internal { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id, serde_json::json!("req-xyz")); + } + other => panic!("expected Response, got {other:?}"), + } + } + + // ── Full roundtrip: internal → rmcp → internal ────────────────────────── + + #[test] + fn full_server_roundtrip_request_id_preserved() { + // Layer 2: convert incoming request to rmcp + let original = make_request("ping", serde_json::json!(99), None); + let rmcp_rx = internal_to_rmcp_server_rx(&original).unwrap(); + + // Extract the ID that rmcp sees + let rmcp_value = serde_json::to_value(&rmcp_rx).unwrap(); + let id_seen_by_rmcp = rmcp_value["id"].clone(); + assert_eq!(id_seen_by_rmcp, serde_json::json!(99)); + + // Layer 4: rmcp produces a response with the same ID echoed back + let rmcp_tx = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::Number(99), + ); + let response = rmcp_server_tx_to_internal(rmcp_tx).unwrap(); + + // The response ID must equal the original request ID + assert_eq!(response.id(), Some(&serde_json::json!(99))); + } + + #[test] + fn full_client_roundtrip_response_id_preserved() { + // Client side: rmcp produces an outbound request + let rmcp_tx = ClientJsonRpcMessage::response( + ClientResult::empty(()), + RequestId::Number(7), + ); + let internal = rmcp_client_tx_to_internal(rmcp_tx).unwrap(); + assert_eq!(internal.id(), Some(&serde_json::json!(7))); + + // And an incoming server response converts to rmcp correctly + let incoming_response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(7), + result: serde_json::json!({"tools": []}), + }); + let rmcp_rx = internal_to_rmcp_client_rx(&incoming_response).unwrap(); + let v = serde_json::to_value(&rmcp_rx).unwrap(); + assert_eq!(v["id"], serde_json::json!(7)); + assert_eq!(v["result"]["tools"], serde_json::json!([])); + } + + // ── Layer 5: ID correlation map logic (mirrors NostrServerWorker) ──────── + + #[test] + fn layer5_worker_correlation_map_number_id() { + let mut request_id_to_event_id: HashMap = HashMap::new(); + let fake_event_id = "aaaaaa".to_string(); + + // Step 1: incoming request arrives — worker stores req_id → event_id + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }); + + if let JsonRpcMessage::Request(ref r) = req { + let key = serde_json::to_string(&r.id).unwrap(); + request_id_to_event_id.insert(key, fake_event_id.clone()); + } + + // Step 2: rmcp response comes back with id=42 + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + result: serde_json::json!({}), + }); + + // Step 3: worker looks up the event_id to call send_response + if let JsonRpcMessage::Response(ref r) = response { + let key = serde_json::to_string(&r.id).unwrap(); + let found = request_id_to_event_id.remove(&key); + assert_eq!(found, Some(fake_event_id)); + } else { + panic!("expected Response"); + } + + // Map should be empty after handling + assert!(request_id_to_event_id.is_empty()); + } + + #[test] + fn layer5_worker_correlation_map_string_id() { + let mut request_id_to_event_id: HashMap = HashMap::new(); + let fake_event_id = "bbbbbb".to_string(); + + // String IDs serialize with surrounding quotes: "\"req-abc\"" + let req_id = serde_json::json!("req-abc"); + let key = serde_json::to_string(&req_id).unwrap(); + request_id_to_event_id.insert(key.clone(), fake_event_id.clone()); + + // The response ID serializes identically + let resp_id = serde_json::json!("req-abc"); + let resp_key = serde_json::to_string(&resp_id).unwrap(); + + // Key derived from response ID must match the one stored from request ID + assert_eq!(key, resp_key); + assert_eq!(request_id_to_event_id.remove(&resp_key), Some(fake_event_id)); + } + + #[test] + fn layer5_error_response_correlation_works() { + let mut map: HashMap = HashMap::new(); + map.insert(serde_json::to_string(&serde_json::json!(5)).unwrap(), "evt5".to_string()); + + let error_response = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(5), + error: JsonRpcError { + code: -32601, + message: "Method not found".to_string(), + data: None, + }, + }); + + if let JsonRpcMessage::ErrorResponse(ref r) = error_response { + let key = serde_json::to_string(&r.id).unwrap(); + assert_eq!(map.remove(&key), Some("evt5".to_string())); + } + } + + // ── Helper ────────────────────────────────────────────────────────────── + + fn make_request( + method: &str, + id: serde_json::Value, + params: Option, + ) -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params, + }) + } +} diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 6807fbc..72f729f 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -3,6 +3,8 @@ //! This file defines wrapper types that bind existing ContextVM Nostr //! transports to rmcp's worker abstraction. +use std::collections::HashMap; + use crate::core::error::Result; use crate::core::types::JsonRpcMessage; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; @@ -17,6 +19,10 @@ use super::convert::{ /// rmcp server worker wrapper for ContextVM Nostr server transport. pub struct NostrServerWorker { transport: NostrServerTransport, + // rmcp service instance is single-peer. Keep one active client per worker. + active_client_pubkey: Option, + // Maps request id (serialized JSON value) -> incoming Nostr event id. + request_id_to_event_id: HashMap, } impl NostrServerWorker { @@ -26,7 +32,11 @@ impl NostrServerWorker { T: nostr_sdk::prelude::IntoNostrSigner, { let transport = NostrServerTransport::new(signer, config).await?; - Ok(Self { transport }) + Ok(Self { + transport, + active_client_pubkey: None, + request_id_to_event_id: HashMap::new(), + }) } /// Access the wrapped transport. @@ -71,16 +81,45 @@ impl Worker for NostrServerWorker { break WorkerQuitReason::Cancelled; } incoming = rx.recv() => { - let Some(mut incoming) = incoming else { + let Some(incoming) = incoming else { break WorkerQuitReason::TransportClosed; }; - // Use event_id as internal request id to keep response routing deterministic. - if let JsonRpcMessage::Request(ref mut req) = incoming.message { - req.id = serde_json::json!(incoming.event_id.clone()); + let crate::transport::server::IncomingRequest { + message, + client_pubkey, + event_id, + .. + } = incoming; + + match &self.active_client_pubkey { + Some(active) if active != &client_pubkey => { + tracing::warn!( + active_client = %active, + ignored_client = %client_pubkey, + "Ignoring message from second client: rmcp server worker currently supports one active client per worker" + ); + continue; + } + None => { + tracing::info!(client_pubkey = %client_pubkey, "Binding rmcp server worker to first client session"); + self.active_client_pubkey = Some(client_pubkey.clone()); + } + _ => {} + } + + if let JsonRpcMessage::Request(req) = &message { + match serde_json::to_string(&req.id) { + Ok(request_key) => { + self.request_id_to_event_id.insert(request_key, event_id); + } + Err(e) => { + tracing::warn!("Failed to serialize request id for correlation map: {e}"); + } + } } - if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&incoming.message) { + if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { if let Err(reason) = context.send_to_handler(rmcp_msg).await { break reason; } @@ -212,35 +251,71 @@ impl Worker for NostrClientWorker { } impl NostrServerWorker { - async fn forward_server_internal(&self, message: JsonRpcMessage) -> Result<()> { + async fn forward_server_internal(&mut self, message: JsonRpcMessage) -> Result<()> { match message { JsonRpcMessage::Response(resp) => { - let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( - "rmcp server response id must be a string event id".to_string(), - ) + let request_key = serde_json::to_string(&resp.id).map_err(|e| { + crate::core::error::Error::Validation(format!( + "failed to serialize rmcp response id for correlation lookup: {e}" + )) })?; + + let event_id = if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { + event_id + } else { + resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server response id has no known correlation mapping and is not a string event id" + .to_string(), + ) + })? + }; + self.transport .send_response(&event_id, JsonRpcMessage::Response(resp)) .await } JsonRpcMessage::ErrorResponse(resp) => { - let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( - "rmcp server error response id must be a string event id".to_string(), - ) + let request_key = serde_json::to_string(&resp.id).map_err(|e| { + crate::core::error::Error::Validation(format!( + "failed to serialize rmcp error response id for correlation lookup: {e}" + )) })?; + + let event_id = if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { + event_id + } else { + resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server error response id has no known correlation mapping and is not a string event id" + .to_string(), + ) + })? + }; + self.transport .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) .await } JsonRpcMessage::Notification(notification) => { + let target = self.active_client_pubkey.as_deref().ok_or_else(|| { + crate::core::error::Error::Validation( + "cannot forward rmcp server notification: no active client bound" + .to_string(), + ) + })?; let message = JsonRpcMessage::Notification(notification); - self.transport.broadcast_notification(&message).await + self.transport.send_notification(target, &message, None).await + } + JsonRpcMessage::Request(request) => { + let target = self.active_client_pubkey.as_deref().ok_or_else(|| { + crate::core::error::Error::Validation( + "cannot forward rmcp server request: no active client bound".to_string(), + ) + })?; + let message = JsonRpcMessage::Request(request); + self.transport.send_notification(target, &message, None).await } - JsonRpcMessage::Request(_) => Err(crate::core::error::Error::Validation( - "rmcp server worker cannot forward outbound request messages".to_string(), - )), } } } diff --git a/src/transport/client.rs b/src/transport/client.rs index 5a0255b..fbe4c42 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -17,7 +17,7 @@ use crate::core::types::*; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; -use rmcp::model::ProtocolVersion; + /// Configuration for the client transport. pub struct NostrClientTransportConfig { @@ -134,10 +134,12 @@ impl NostrClientTransport { .send_mcp_message(message, &self.server_pubkey, CTXVM_MESSAGES_KIND, tags, None) .await?; - self.pending_requests - .write() - .await - .insert(event_id.to_hex()); + if matches!(message, JsonRpcMessage::Request(_)) { + self.pending_requests + .write() + .await + .insert(event_id.to_hex()); + } Ok(()) } @@ -154,7 +156,7 @@ impl NostrClientTransport { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": ProtocolVersion::LATEST.to_string(), + "protocolVersion": crate::core::constants::MCP_PROTOCOL_VERSION, "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" @@ -277,7 +279,7 @@ mod tests { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": ProtocolVersion::LATEST.to_string(), + "protocolVersion": crate::core::constants::MCP_PROTOCOL_VERSION, "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" diff --git a/src/transport/server.rs b/src/transport/server.rs index c6483a6..d45e1f9 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -395,6 +395,55 @@ impl NostrServerTransport { Ok(()) } + /// Publish tools list from rmcp typed tool descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_tools_typed(&self, tools: Vec) -> Result { + let tools = tools + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_tools(tools).await + } + + /// Publish resources list from rmcp typed resource descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_resources_typed( + &self, + resources: Vec, + ) -> Result { + let resources = resources + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_resources(resources).await + } + + /// Publish prompts list from rmcp typed prompt descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_prompts_typed( + &self, + prompts: Vec, + ) -> Result { + let prompts = prompts + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_prompts(prompts).await + } + + /// Publish resource templates list from rmcp typed template descriptors. + #[cfg(feature = "rmcp")] + pub async fn publish_resource_templates_typed( + &self, + templates: Vec, + ) -> Result { + let templates = templates + .into_iter() + .map(serde_json::to_value) + .collect::, _>>()?; + self.publish_resource_templates(templates).await + } + // ── Internal ──────────────────────────────────────────────── fn is_capability_excluded( From 5505e3708b7e17b5237cc762ab3e0256253f7e90 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Fri, 3 Apr 2026 20:48:18 +0530 Subject: [PATCH 08/69] fix: fixed duplicate dependency warning --- Cargo.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6b73c7d..0ca4b98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" -rmcp = "1.3.0" # Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } @@ -35,4 +34,6 @@ rmcp = ["dep:rmcp"] [dev-dependencies] tokio-test = "0.4" -tracing-subscriber = "0.3" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +anyhow = "1" +schemars = "0.8" From d3d9a2fe7a3d32bfd73c3acf08f1bf1f6113dbd9 Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Fri, 3 Apr 2026 22:49:02 +0530 Subject: [PATCH 09/69] feat(rmcp): add typed discovery and client service integration - Implement discover_*_typed methods to parse Nostr discovery events into rmcp models - Add NostrMCPProxy::serve_client_handler to support rmcp client handler integration - Add parse_typed_list internal helper for discovery result deserialization --- src/discovery/mod.rs | 60 +++++++++++++++++++++++++++++++++++++++++++- src/proxy/mod.rs | 28 ++++++++++++++++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index eb0e876..e198674 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -120,6 +120,50 @@ pub async fn discover_resource_templates( .await } +/// Discover tools and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_tools_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_tools(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover resources and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_resources_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_resources(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover prompts and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_prompts_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_prompts(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + +/// Discover resource templates and parse them into rmcp typed descriptors. +#[cfg(feature = "rmcp")] +pub async fn discover_resource_templates_typed( + client: &Arc, + server_pubkey: &PublicKey, + relay_urls: &[String], +) -> Result> { + let raw = discover_resource_templates(client, server_pubkey, relay_urls).await?; + parse_typed_list(raw) +} + // ── Internal ──────────────────────────────────────────────────────── async fn fetch_list( @@ -153,6 +197,20 @@ async fn fetch_list( .unwrap_or_default()) } +#[cfg(feature = "rmcp")] +fn parse_typed_list(raw: Vec) -> Result> +where + T: serde::de::DeserializeOwned, +{ + let mut parsed = Vec::new(); + for item in raw { + let value = serde_json::from_value(item) + .map_err(|e| Error::Other(format!("Failed to parse typed discovery item: {e}")))?; + parsed.push(value); + } + Ok(parsed) +} + #[cfg(test)] mod tests { use super::*; @@ -229,4 +287,4 @@ mod tests { assert_eq!(announcement.pubkey, pubkey.to_hex()); assert_eq!(announcement.server_info.name, Some("Test".to_string())); } -} +} \ No newline at end of file diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index fab56d7..d7b75f5 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -70,6 +70,32 @@ impl NostrMCPProxy { } } +#[cfg(feature = "rmcp")] +impl NostrMCPProxy { + /// Start a proxy directly from an rmcp client handler. + /// + /// This additive API keeps the existing `new/start/send` flow intact, + /// while allowing rmcp-first usage through the worker adapter. + pub async fn serve_client_handler( + signer: T, + config: ProxyConfig, + handler: H, + ) -> Result> + where + T: nostr_sdk::prelude::IntoNostrSigner, + H: rmcp::ClientHandler, + { + use crate::rmcp_transport::NostrClientWorker; + use rmcp::ServiceExt; + + let worker = NostrClientWorker::new(signer, config.nostr_config).await?; + handler + .serve(worker) + .await + .map_err(|e| Error::Other(format!("rmcp client initialization failed: {e}"))) + } +} + #[cfg(test)] mod tests { use super::*; @@ -107,4 +133,4 @@ mod tests { assert!(!config.nostr_config.is_stateless); assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); } -} +} \ No newline at end of file From 84911d7a81f374024778a9ec1f60085ff91737a6 Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Sat, 4 Apr 2026 02:17:06 +0530 Subject: [PATCH 10/69] feat: add rmcp integration test example Adds an integration test matrix for RMCP transport, covering local, hybrid, and full relay-based communication scenarios. --- examples/rmcp_integration_test.rs | 726 ++++++++++++++++++++++++++++++ 1 file changed, 726 insertions(+) create mode 100644 examples/rmcp_integration_test.rs diff --git a/examples/rmcp_integration_test.rs b/examples/rmcp_integration_test.rs new file mode 100644 index 0000000..75990ba --- /dev/null +++ b/examples/rmcp_integration_test.rs @@ -0,0 +1,726 @@ +//! Comprehensive rmcp integration matrix for ContextVM SDK. +//! +//! This example validates three scenarios: +//! 1) local rmcp transport (in-process duplex) +//! 2) hybrid relay mode (rmcp server + legacy JSON-RPC client) +//! 3) full rmcp over relays (rmcp server + rmcp client) +//! +//! Run: +//! cargo run --example rmcp_integration_test --features rmcp +//! cargo run --example rmcp_integration_test --features rmcp -- local +//! cargo run --example rmcp_integration_test --features rmcp -- hybrid +//! cargo run --example rmcp_integration_test --features rmcp -- relay-rmcp +//! cargo run --example rmcp_integration_test --features rmcp -- all +//! +//! Optional relay override: +//! CTXVM_RELAY_URL=wss://relay.primal.net cargo run --example rmcp_integration_test --features rmcp -- all +//! cargo run --example rmcp_integration_test --features rmcp -- all wss://relay.primal.net + +use anyhow::{anyhow, bail, Context, Result}; +use contextvm_sdk::core::constants::MCP_PROTOCOL_VERSION; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, ServerInfo as CtxServerInfo, +}; +use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; +use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; +use contextvm_sdk::signer; +use contextvm_sdk::transport::client::NostrClientTransportConfig; +use contextvm_sdk::transport::server::NostrServerTransportConfig; +use rmcp::{ + handler::server::wrapper::Parameters, + model::*, + schemars, + service::RequestContext, + tool, tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, ServiceExt, +}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tokio::time::{sleep, timeout}; + +const DEFAULT_RELAY_URL: &str = "wss://relay.primal.net"; +const IO_TIMEOUT: Duration = Duration::from_secs(30); +const RELAY_WARMUP: Duration = Duration::from_secs(2); +const STARTUP_TIMEOUT: Duration = Duration::from_secs(20); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Mode { + Local, + Hybrid, + RelayRmcp, + All, +} + +impl Mode { + fn parse(value: Option<&str>) -> Result { + match value.unwrap_or("all") { + "local" => Ok(Self::Local), + "hybrid" => Ok(Self::Hybrid), + "relay-rmcp" => Ok(Self::RelayRmcp), + "all" => Ok(Self::All), + other => bail!( + "Unknown mode '{other}'. Use one of: local | hybrid | relay-rmcp | all" + ), + } + } + + fn run_local(self) -> bool { + matches!(self, Self::Local | Self::All) + } + + fn run_hybrid(self) -> bool { + matches!(self, Self::Hybrid | Self::All) + } + + fn run_relay_rmcp(self) -> bool { + matches!(self, Self::RelayRmcp | Self::All) + } +} + +// Parameter structs with JSON schema for tools/list. +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct AddParams { + a: i64, + b: i64, +} + +use rmcp::handler::server::router::tool::ToolRouter; + +#[derive(Clone)] +struct DemoServer { + echo_count: Arc>, + tool_router: ToolRouter, +} + +impl DemoServer { + fn new() -> Self { + Self { + echo_count: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl DemoServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + let mut n = self.echo_count.lock().await; + *n += 1; + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo #{n}: {message}" + ))])) + } + + #[tool(description = "Add two integers and return their sum")] + fn add( + &self, + Parameters(AddParams { a, b }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "{a} + {b} = {}", + a + b + ))])) + } + + #[tool(description = "Return the total number of echo calls made so far")] + async fn get_echo_count(&self) -> Result { + let n = self.echo_count.lock().await; + Ok(CallToolResult::success(vec![Content::text(format!( + "Total echo calls: {n}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for DemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + server_info: Implementation { + name: "contextvm-demo".to_string(), + title: Some("ContextVM Demo Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Demonstrates rmcp integration over ContextVM".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Try: echo, add, get_echo_count".to_string()), + } + } + + async fn list_resources( + &self, + _req: Option, + _ctx: RequestContext, + ) -> Result { + Ok(ListResourcesResult { + resources: vec![ + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation(), + ], + next_cursor: None, + meta: None, + }) + } + + async fn read_resource( + &self, + req: ReadResourceRequestParams, + _ctx: RequestContext, + ) -> Result { + match req.uri.as_str() { + "demo://readme" => Ok(ReadResourceResult { + contents: vec![ResourceContents::text( + "This server demonstrates the ContextVM rmcp integration.", + req.uri, + )], + }), + other => Err(ErrorData::resource_not_found( + "not_found", + Some(serde_json::json!({ "uri": other })), + )), + } + } +} + +#[derive(Clone, Default)] +struct DemoClient; +impl ClientHandler for DemoClient {} + +#[derive(Clone, Default)] +struct RelayRmcpClient; +impl ClientHandler for RelayRmcpClient {} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("rmcp=warn".parse()?) + .add_directive("contextvm_sdk=info".parse()?), + ) + .init(); + + let args: Vec = std::env::args().skip(1).collect(); + let mode = Mode::parse(args.first().map(String::as_str))?; + let relay_url = args + .get(1) + .cloned() + .or_else(|| std::env::var("CTXVM_RELAY_URL").ok()) + .unwrap_or_else(|| DEFAULT_RELAY_URL.to_string()); + + println!("========================================"); + println!("ContextVM SDK rmcp integration matrix"); + println!("mode: {:?}", mode); + println!("relay: {relay_url}"); + println!("========================================\n"); + + if mode.run_local() { + run_local_rmcp_case().await?; + } + + if mode.run_hybrid() { + run_hybrid_relay_case(&relay_url).await?; + } + + if mode.run_relay_rmcp() { + run_relay_rmcp_case(&relay_url).await?; + } + + println!("\nAll selected integration scenarios passed."); + Ok(()) +} + +async fn run_local_rmcp_case() -> Result<()> { + println!("[local-rmcp] start"); + + let (server_io, client_io) = tokio::io::duplex(65536); + + let server_handle = tokio::spawn(async move { + DemoServer::new() + .serve(server_io) + .await + .expect("server serve failed") + .waiting() + .await + .expect("server error"); + }); + + let client = DemoClient.serve(client_io).await?; + + let tools = client.list_all_tools().await?; + assert_eq!(tools.len(), 3, "expected 3 tools in local rmcp case"); + + let add_result = client + .call_tool(call_params( + "add", + Some(serde_json::json!({ "a": 7, "b": 5 })), + )) + .await?; + let add_text = first_text(&add_result); + assert!(add_text.contains("12"), "expected add result to include 12"); + + let resources = client.list_all_resources().await?; + assert_eq!(resources.len(), 1, "expected one resource in local rmcp case"); + + match client.call_tool(call_params("no_such_tool", None)).await { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => bail!("expected unknown tool to fail in local rmcp case"), + } + + client.cancel().await?; + server_handle.abort(); + + println!("[local-rmcp] pass"); + Ok(()) +} + +async fn run_hybrid_relay_case(relay_url: &str) -> Result<()> { + println!("[relay-hybrid] start (rmcp server + legacy client)"); + + let server_keys = signer::generate(); + let server_pubkey_hex = server_keys.public_key().to_hex(); + + println!("[relay-hybrid] stage: spawning rmcp server task"); + let relay_url_owned = relay_url.to_string(); + let server_task = tokio::spawn(async move { + let server = NostrMCPGateway::serve_handler( + server_keys, + server_config(&relay_url_owned), + DemoServer::new(), + ) + .await + .with_context(|| { + format!("failed to start rmcp server on relay {relay_url_owned}") + })?; + + let _ = server + .waiting() + .await + .map_err(|e| anyhow!("rmcp server exited with error: {e}"))?; + + Err(anyhow!("rmcp server stopped unexpectedly")) + }); + + sleep(RELAY_WARMUP).await; + + if server_task.is_finished() { + let res = server_task + .await + .map_err(|e| anyhow!("rmcp server task join error: {e}"))?; + return res.context("rmcp server task ended before client startup"); + } + + let outcome: Result<()> = async { + println!("[relay-hybrid] stage: creating legacy proxy client"); + + let mut proxy = timeout( + STARTUP_TIMEOUT, + NostrMCPProxy::new( + signer::generate(), + client_config(relay_url, server_pubkey_hex.clone()), + ), + ) + .await + .with_context(|| { + format!( + "timed out creating legacy proxy client after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to create legacy proxy client")?; + + println!("[relay-hybrid] stage: starting legacy proxy transport"); + let mut rx = timeout(STARTUP_TIMEOUT, proxy.start()) + .await + .with_context(|| { + format!( + "timed out starting legacy proxy transport after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to start legacy proxy")?; + println!("[relay-hybrid] stage: legacy proxy started"); + + let init_id = serde_json::json!(1); + let init_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: init_id.clone(), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": { + "tools": {}, + "resources": {} + }, + "clientInfo": { + "name": "legacy-hybrid-client", + "version": "0.1.0" + } + })), + }); + + let init_response = + send_legacy_request_and_wait(&proxy, &mut rx, init_request, &init_id).await?; + assert_initialize_shape(&init_response)?; + + proxy + .send(&JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + })) + .await + .context("failed to send initialized notification")?; + + let tools_id = serde_json::json!(2); + let tools_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: tools_id.clone(), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + let tools_response = + send_legacy_request_and_wait(&proxy, &mut rx, tools_request, &tools_id).await?; + let tools = extract_tools_list(&tools_response)?; + assert!( + tools + .iter() + .any(|t| t.get("name") == Some(&serde_json::json!("echo"))), + "expected echo tool in hybrid case" + ); + + let call_id = serde_json::json!(3); + let call_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: call_id.clone(), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "echo", + "arguments": { "message": "legacy-client-hello" } + })), + }); + + let call_response = send_legacy_request_and_wait(&proxy, &mut rx, call_request, &call_id) + .await + .context("tools/call failed in hybrid case")?; + let echo_text = extract_first_content_text(&call_response)?; + assert!( + echo_text.contains("legacy-client-hello"), + "unexpected echo output in hybrid case: {echo_text}" + ); + + let unknown_id = serde_json::json!(4); + let unknown_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: unknown_id.clone(), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "no_such_tool", + "arguments": {} + })), + }); + + let unknown_response = + send_legacy_request_and_wait(&proxy, &mut rx, unknown_request, &unknown_id).await?; + assert_error_response(&unknown_response)?; + + proxy.stop().await.context("failed to stop legacy proxy")?; + + Ok(()) + } + .await; + + server_task.abort(); + + if server_task.is_finished() { + let _ = server_task.await; + } + + outcome?; + + println!("[relay-hybrid] pass"); + Ok(()) +} + +async fn run_relay_rmcp_case(relay_url: &str) -> Result<()> { + println!("[relay-rmcp] start (rmcp server + rmcp client)"); + + let server_keys = signer::generate(); + let server_pubkey_hex = server_keys.public_key().to_hex(); + + println!("[relay-rmcp] stage: spawning rmcp server task"); + let relay_url_owned = relay_url.to_string(); + let server_task = tokio::spawn(async move { + let server = NostrMCPGateway::serve_handler( + server_keys, + server_config(&relay_url_owned), + DemoServer::new(), + ) + .await + .with_context(|| { + format!("failed to start rmcp server on relay {relay_url_owned}") + })?; + + let _ = server + .waiting() + .await + .map_err(|e| anyhow!("rmcp server exited with error: {e}"))?; + + Err(anyhow!("rmcp server stopped unexpectedly")) + }); + + sleep(RELAY_WARMUP).await; + + if server_task.is_finished() { + let res = server_task + .await + .map_err(|e| anyhow!("rmcp server task join error: {e}"))?; + return res.context("rmcp server task ended before rmcp client startup"); + } + + let outcome: Result<()> = async { + println!("[relay-rmcp] stage: starting rmcp relay client worker"); + + let client = timeout( + STARTUP_TIMEOUT, + NostrMCPProxy::serve_client_handler( + signer::generate(), + client_config(relay_url, server_pubkey_hex), + RelayRmcpClient, + ), + ) + .await + .with_context(|| { + format!( + "timed out starting rmcp relay client worker after {:?}", + STARTUP_TIMEOUT + ) + })? + .context("failed to start rmcp relay client")?; + println!("[relay-rmcp] stage: rmcp relay client started"); + + let peer = client + .peer_info() + .ok_or_else(|| anyhow!("rmcp relay client did not receive peer info"))?; + let negotiated = peer.protocol_version.to_string(); + assert!( + is_supported_protocol(&negotiated), + "unexpected negotiated protocol version: {negotiated}" + ); + + let tools = client.list_all_tools().await?; + assert!( + tools.iter().any(|t| t.name == "echo"), + "expected echo tool in rmcp relay case" + ); + + let echo = client + .call_tool(call_params( + "echo", + Some(serde_json::json!({ "message": "rmcp-relay-hello" })), + )) + .await?; + let echo_text = first_text(&echo); + assert!( + echo_text.contains("rmcp-relay-hello"), + "unexpected rmcp relay echo output: {echo_text}" + ); + + let resources = client.list_all_resources().await?; + assert!( + resources.iter().any(|r| r.uri.as_str() == "demo://readme"), + "expected demo://readme resource in rmcp relay case" + ); + + match client.call_tool(call_params("no_such_tool", None)).await { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => bail!("expected unknown tool to fail in rmcp relay case"), + } + + client + .cancel() + .await + .context("failed to cancel rmcp relay client")?; + + Ok(()) + } + .await; + + server_task.abort(); + + if server_task.is_finished() { + let _ = server_task.await; + } + + outcome?; + + println!("[relay-rmcp] pass"); + Ok(()) +} + +fn server_config(relay_url: &str) -> GatewayConfig { + GatewayConfig { + nostr_config: NostrServerTransportConfig { + relay_urls: vec![relay_url.to_string()], + encryption_mode: EncryptionMode::Optional, + server_info: Some(CtxServerInfo { + name: Some("rmcp-matrix-server".to_string()), + about: Some("rmcp matrix coverage server".to_string()), + ..Default::default() + }), + is_announced_server: false, + ..Default::default() + }, + } +} + +fn client_config(relay_url: &str, server_pubkey: String) -> ProxyConfig { + ProxyConfig { + nostr_config: NostrClientTransportConfig { + relay_urls: vec![relay_url.to_string()], + server_pubkey, + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + } +} + +async fn send_legacy_request_and_wait( + proxy: &NostrMCPProxy, + rx: &mut tokio::sync::mpsc::UnboundedReceiver, + request: JsonRpcMessage, + expected_id: &serde_json::Value, +) -> Result { + proxy.send(&request).await?; + + loop { + let maybe_msg = timeout(IO_TIMEOUT, rx.recv()) + .await + .context("timed out waiting for legacy response")?; + + let msg = maybe_msg.ok_or_else(|| anyhow!("legacy response channel closed"))?; + + if msg.id() == Some(expected_id) { + return Ok(msg); + } + + if msg.is_notification() { + continue; + } + } +} + +fn extract_tools_list(response: &JsonRpcMessage) -> Result<&Vec> { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected tools/list response, got {response:?}"); + }; + + resp.result + .get("tools") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow!("tools/list response missing tools array")) +} + +fn extract_first_content_text(response: &JsonRpcMessage) -> Result { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected tools/call response, got {response:?}"); + }; + + let text = resp + .result + .get("content") + .and_then(|v| v.as_array()) + .and_then(|items| items.first()) + .and_then(|item| item.get("text")) + .and_then(|text| text.as_str()) + .ok_or_else(|| anyhow!("tools/call response missing content[0].text"))?; + + Ok(text.to_string()) +} + +fn assert_initialize_shape(response: &JsonRpcMessage) -> Result<()> { + let JsonRpcMessage::Response(resp) = response else { + bail!("expected initialize response, got {response:?}"); + }; + + let protocol = resp + .result + .get("protocolVersion") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow!("initialize response missing protocolVersion"))?; + + if !is_supported_protocol(protocol) { + bail!( + "unexpected protocolVersion in initialize response: expected one of [{MCP_PROTOCOL_VERSION}, {}], got {protocol}", + ProtocolVersion::LATEST + ); + } + + if resp.result.get("serverInfo").is_none() { + bail!("initialize response missing serverInfo"); + } + + Ok(()) +} + +fn is_supported_protocol(protocol: &str) -> bool { + protocol == MCP_PROTOCOL_VERSION || protocol == ProtocolVersion::LATEST.to_string() +} + +fn assert_error_response(response: &JsonRpcMessage) -> Result<()> { + match response { + JsonRpcMessage::ErrorResponse(err) => { + if err.error.code >= 0 { + bail!("expected negative JSON-RPC error code, got {}", err.error.code); + } + Ok(()) + } + JsonRpcMessage::Response(resp) => { + if resp.result.get("isError") == Some(&serde_json::json!(true)) { + Ok(()) + } else { + bail!("expected error response but received success result") + } + } + _ => bail!("expected error response, got {response:?}"), + } +} + +fn call_params(name: &'static str, args: Option) -> CallToolRequestParams { + CallToolRequestParams { + name: name.into(), + arguments: args.and_then(|v| serde_json::from_value(v).ok()), + meta: None, + task: None, + } +} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|content| { + if let RawContent::Text(t) = &content.raw { + Some(t.text.clone()) + } else { + None + } + }) + .unwrap_or_default() +} \ No newline at end of file From 61da0162424dd2071f34d8b45dcbbfa4139283d5 Mon Sep 17 00:00:00 2001 From: Akhil Dhyani Date: Sat, 4 Apr 2026 16:52:05 +0530 Subject: [PATCH 11/69] chore: cargo formatting fixes --- examples/discovery.rs | 3 +- examples/proxy.rs | 5 +- examples/rmcp_integration_test.rs | 71 ++-- src/core/constants.rs | 1 - src/core/serializers.rs | 2 +- src/discovery/mod.rs | 13 +- src/encryption/mod.rs | 18 +- src/gateway/mod.rs | 31 +- src/proxy/mod.rs | 21 +- src/rmcp_transport/mod.rs | 4 +- src/rmcp_transport/pipeline_tests.rs | 48 +-- src/rmcp_transport/worker.rs | 560 ++++++++++++++------------- src/transport/base.rs | 76 +++- src/transport/client.rs | 76 ++-- src/transport/server.rs | 254 +++++++----- 15 files changed, 668 insertions(+), 515 deletions(-) diff --git a/examples/discovery.rs b/examples/discovery.rs index 166cb3f..edff853 100644 --- a/examples/discovery.rs +++ b/examples/discovery.rs @@ -53,8 +53,7 @@ async fn main() -> contextvm_sdk::Result<()> { println!(" Resources: {} found", resources.len()); } - let prompts = - discovery::discover_prompts(client, &server.pubkey_parsed, &relays).await?; + let prompts = discovery::discover_prompts(client, &server.pubkey_parsed, &relays).await?; if !prompts.is_empty() { println!(" Prompts: {} found", prompts.len()); } diff --git a/examples/proxy.rs b/examples/proxy.rs index a10663a..c0fcd55 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -42,7 +42,10 @@ async fn main() -> contextvm_sdk::Result<()> { // Wait for response if let Some(response) = rx.recv().await { - println!("Response: {}", serde_json::to_string_pretty(&response).unwrap()); + println!( + "Response: {}", + serde_json::to_string_pretty(&response).unwrap() + ); } proxy.stop().await?; diff --git a/examples/rmcp_integration_test.rs b/examples/rmcp_integration_test.rs index 75990ba..c9656ef 100644 --- a/examples/rmcp_integration_test.rs +++ b/examples/rmcp_integration_test.rs @@ -19,7 +19,8 @@ use anyhow::{anyhow, bail, Context, Result}; use contextvm_sdk::core::constants::MCP_PROTOCOL_VERSION; use contextvm_sdk::core::types::{ - EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, ServerInfo as CtxServerInfo, + EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, + ServerInfo as CtxServerInfo, }; use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; @@ -27,11 +28,8 @@ use contextvm_sdk::signer; use contextvm_sdk::transport::client::NostrClientTransportConfig; use contextvm_sdk::transport::server::NostrServerTransportConfig; use rmcp::{ - handler::server::wrapper::Parameters, - model::*, - schemars, - service::RequestContext, - tool, tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, ServiceExt, + handler::server::wrapper::Parameters, model::*, schemars, service::RequestContext, tool, + tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, ServiceExt, }; use std::sync::Arc; use std::time::Duration; @@ -58,9 +56,7 @@ impl Mode { "hybrid" => Ok(Self::Hybrid), "relay-rmcp" => Ok(Self::RelayRmcp), "all" => Ok(Self::All), - other => bail!( - "Unknown mode '{other}'. Use one of: local | hybrid | relay-rmcp | all" - ), + other => bail!("Unknown mode '{other}'. Use one of: local | hybrid | relay-rmcp | all"), } } @@ -168,7 +164,7 @@ impl ServerHandler for DemoServer { ) -> Result { Ok(ListResourcesResult { resources: vec![ - RawResource::new("demo://readme", "Demo README".to_string()).no_annotation(), + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation() ], next_cursor: None, meta: None, @@ -273,7 +269,11 @@ async fn run_local_rmcp_case() -> Result<()> { assert!(add_text.contains("12"), "expected add result to include 12"); let resources = client.list_all_resources().await?; - assert_eq!(resources.len(), 1, "expected one resource in local rmcp case"); + assert_eq!( + resources.len(), + 1, + "expected one resource in local rmcp case" + ); match client.call_tool(call_params("no_such_tool", None)).await { Err(_) => {} @@ -303,9 +303,7 @@ async fn run_hybrid_relay_case(relay_url: &str) -> Result<()> { DemoServer::new(), ) .await - .with_context(|| { - format!("failed to start rmcp server on relay {relay_url_owned}") - })?; + .with_context(|| format!("failed to start rmcp server on relay {relay_url_owned}"))?; let _ = server .waiting() @@ -355,23 +353,23 @@ async fn run_hybrid_relay_case(relay_url: &str) -> Result<()> { .context("failed to start legacy proxy")?; println!("[relay-hybrid] stage: legacy proxy started"); - let init_id = serde_json::json!(1); - let init_request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: init_id.clone(), - method: "initialize".to_string(), - params: Some(serde_json::json!({ - "protocolVersion": MCP_PROTOCOL_VERSION, - "capabilities": { - "tools": {}, - "resources": {} - }, - "clientInfo": { - "name": "legacy-hybrid-client", - "version": "0.1.0" - } - })), - }); + let init_id = serde_json::json!(1); + let init_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: init_id.clone(), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": { + "tools": {}, + "resources": {} + }, + "clientInfo": { + "name": "legacy-hybrid-client", + "version": "0.1.0" + } + })), + }); let init_response = send_legacy_request_and_wait(&proxy, &mut rx, init_request, &init_id).await?; @@ -472,9 +470,7 @@ async fn run_relay_rmcp_case(relay_url: &str) -> Result<()> { DemoServer::new(), ) .await - .with_context(|| { - format!("failed to start rmcp server on relay {relay_url_owned}") - })?; + .with_context(|| format!("failed to start rmcp server on relay {relay_url_owned}"))?; let _ = server .waiting() @@ -687,7 +683,10 @@ fn assert_error_response(response: &JsonRpcMessage) -> Result<()> { match response { JsonRpcMessage::ErrorResponse(err) => { if err.error.code >= 0 { - bail!("expected negative JSON-RPC error code, got {}", err.error.code); + bail!( + "expected negative JSON-RPC error code, got {}", + err.error.code + ); } Ok(()) } @@ -723,4 +722,4 @@ fn first_text(result: &CallToolResult) -> String { } }) .unwrap_or_default() -} \ No newline at end of file +} diff --git a/src/core/constants.rs b/src/core/constants.rs index 8aecf78..b7f8a8b 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -217,4 +217,3 @@ mod tests { assert!(!UNENCRYPTED_KINDS.contains(&EPHEMERAL_GIFT_WRAP_KIND)); } } - diff --git a/src/core/serializers.rs b/src/core/serializers.rs index 3d641fd..69c519c 100644 --- a/src/core/serializers.rs +++ b/src/core/serializers.rs @@ -50,7 +50,7 @@ pub fn get_tag_value_from_slice(tags: &[Tag], name: &str) -> Option { #[cfg(test)] mod tests { use super::*; - use crate::core::types::{JsonRpcRequest, JsonRpcMessage}; + use crate::core::types::{JsonRpcMessage, JsonRpcRequest}; #[test] fn test_roundtrip() { diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index e198674..d5ffa3a 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -64,8 +64,7 @@ pub async fn discover_servers( let mut announcements = Vec::new(); for event in events { - let server_info: ServerInfo = - serde_json::from_str(&event.content).unwrap_or_default(); + let server_info: ServerInfo = serde_json::from_str(&event.content).unwrap_or_default(); announcements.push(ServerAnnouncement { pubkey: event.pubkey.to_hex(), pubkey_parsed: event.pubkey, @@ -233,7 +232,10 @@ mod tests { assert_eq!(parsed.version, Some("1.0.0".to_string())); assert_eq!(parsed.about, Some("A test MCP server".to_string())); assert_eq!(parsed.website, Some("https://example.com".to_string())); - assert_eq!(parsed.picture, Some("https://example.com/pic.png".to_string())); + assert_eq!( + parsed.picture, + Some("https://example.com/pic.png".to_string()) + ); } #[test] @@ -280,11 +282,12 @@ mod tests { }, event_id: EventId::from_hex( "0000000000000000000000000000000000000000000000000000000000000001", - ).unwrap(), + ) + .unwrap(), created_at: Timestamp::now(), }; assert_eq!(announcement.pubkey, pubkey.to_hex()); assert_eq!(announcement.server_info.name, Some("Test".to_string())); } -} \ No newline at end of file +} diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index dadf331..d5171e0 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -43,10 +43,7 @@ where /// - The gift wrap event has NIP-44 encrypted content (single layer) /// - Decrypt using recipient's key + event's pubkey (ephemeral sender) /// - Returns the decrypted plaintext content string -pub async fn decrypt_gift_wrap_single_layer( - signer: &T, - event: &Event, -) -> Result +pub async fn decrypt_gift_wrap_single_layer(signer: &T, event: &Event) -> Result where T: NostrSigner, { @@ -73,8 +70,8 @@ where let encrypted = encrypt_nip44(&ephemeral, recipient, plaintext).await?; - let builder = EventBuilder::new(Kind::Custom(GIFT_WRAP_KIND), encrypted) - .tag(Tag::public_key(*recipient)); + let builder = + EventBuilder::new(Kind::Custom(GIFT_WRAP_KIND), encrypted).tag(Tag::public_key(*recipient)); builder .sign_with_keys(&ephemeral) @@ -142,10 +139,7 @@ mod tests { /// 2. NIP-44 encrypt the plaintext using ephemeral_secret + recipient_pubkey /// 3. Build kind 1059 event with encrypted content, `p` tag = recipient /// 4. Sign with ephemeral key - async fn create_js_style_gift_wrap( - plaintext: &str, - recipient: &PublicKey, - ) -> (Event, Keys) { + async fn create_js_style_gift_wrap(plaintext: &str, recipient: &PublicKey) -> (Event, Keys) { let ephemeral = Keys::generate(); // Single-layer NIP-44 encrypt @@ -154,8 +148,8 @@ mod tests { .unwrap(); // Build kind 1059 event - let builder = EventBuilder::new(Kind::Custom(1059), encrypted) - .tag(Tag::public_key(*recipient)); + let builder = + EventBuilder::new(Kind::Custom(1059), encrypted).tag(Tag::public_key(*recipient)); let event = builder.sign_with_keys(&ephemeral).unwrap(); (event, ephemeral) diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 345f2b3..427b48f 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -40,9 +40,7 @@ impl NostrMCPGateway { /// /// The caller is responsible for processing requests and calling /// `send_response` for each one. - pub async fn start( - &mut self, - ) -> Result> { + pub async fn start(&mut self) -> Result> { if self.is_running { return Err(Error::Other("Gateway already running".to_string())); } @@ -133,11 +131,27 @@ mod tests { let config = GatewayConfig { nostr_config }; - assert_eq!(config.nostr_config.relay_urls, vec!["wss://relay.example.com"]); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Required); + assert_eq!( + config.nostr_config.relay_urls, + vec!["wss://relay.example.com"] + ); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Required + ); assert!(config.nostr_config.is_announced_server); assert_eq!(config.nostr_config.allowed_public_keys.len(), 1); - assert!(config.nostr_config.server_info.as_ref().unwrap().name.as_ref().unwrap() == "Test Gateway"); + assert!( + config + .nostr_config + .server_info + .as_ref() + .unwrap() + .name + .as_ref() + .unwrap() + == "Test Gateway" + ); } #[test] @@ -145,7 +159,10 @@ mod tests { let config = GatewayConfig { nostr_config: NostrServerTransportConfig::default(), }; - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Optional + ); assert!(!config.nostr_config.is_announced_server); } } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index d7b75f5..25fc2a8 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -34,9 +34,7 @@ impl NostrMCPProxy { } /// Start the proxy. Returns a receiver for incoming responses/notifications. - pub async fn start( - &mut self, - ) -> Result> { + pub async fn start(&mut self) -> Result> { if self.is_running { return Err(Error::Other("Proxy already running".to_string())); } @@ -118,9 +116,15 @@ mod tests { let config = ProxyConfig { nostr_config }; - assert_eq!(config.nostr_config.relay_urls, vec!["wss://relay.example.com"]); + assert_eq!( + config.nostr_config.relay_urls, + vec!["wss://relay.example.com"] + ); assert_eq!(config.nostr_config.server_pubkey, server_pubkey); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Required); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Required + ); assert!(config.nostr_config.is_stateless); assert_eq!(config.nostr_config.timeout, Duration::from_secs(60)); } @@ -131,6 +135,9 @@ mod tests { nostr_config: NostrClientTransportConfig::default(), }; assert!(!config.nostr_config.is_stateless); - assert_eq!(config.nostr_config.encryption_mode, EncryptionMode::Optional); + assert_eq!( + config.nostr_config.encryption_mode, + EncryptionMode::Optional + ); } -} \ No newline at end of file +} diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs index 44a2ab1..57919b5 100644 --- a/src/rmcp_transport/mod.rs +++ b/src/rmcp_transport/mod.rs @@ -9,7 +9,7 @@ pub mod worker; mod pipeline_tests; pub use convert::{ - internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, - rmcp_server_tx_to_internal, + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, }; pub use worker::{NostrClientWorker, NostrServerWorker}; diff --git a/src/rmcp_transport/pipeline_tests.rs b/src/rmcp_transport/pipeline_tests.rs index d8da58d..a1a6859 100644 --- a/src/rmcp_transport/pipeline_tests.rs +++ b/src/rmcp_transport/pipeline_tests.rs @@ -16,13 +16,18 @@ mod tests { use std::collections::HashMap; - use rmcp::model::{ClientJsonRpcMessage, ClientResult, RequestId, ServerJsonRpcMessage, ServerResult}; + use rmcp::model::{ + ClientJsonRpcMessage, ClientResult, RequestId, ServerJsonRpcMessage, ServerResult, + }; use crate::core::serializers; - use crate::core::types::{JsonRpcError, JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse}; + use crate::core::types::{ + JsonRpcError, JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, + JsonRpcResponse, + }; use crate::rmcp_transport::convert::{ - internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, - rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, }; // ── Layer 1: Nostr event content → JsonRpcMessage ────────────────────── @@ -110,8 +115,8 @@ mod tests { method: "notifications/initialized".to_string(), params: None, }); - let rmcp = internal_to_rmcp_server_rx(&msg) - .expect("initialized notification should convert"); + let rmcp = + internal_to_rmcp_server_rx(&msg).expect("initialized notification should convert"); let v = serde_json::to_value(&rmcp).unwrap(); assert_eq!(v["method"], "notifications/initialized"); } @@ -134,12 +139,10 @@ mod tests { #[test] fn layer4_rmcp_ping_response_roundtrip_number_id() { // Simulate rmcp handler producing a ping response - let rmcp_response = ServerJsonRpcMessage::response( - ServerResult::empty(()), - RequestId::Number(42), - ); - let internal = rmcp_server_tx_to_internal(rmcp_response) - .expect("ping response should convert back"); + let rmcp_response = + ServerJsonRpcMessage::response(ServerResult::empty(()), RequestId::Number(42)); + let internal = + rmcp_server_tx_to_internal(rmcp_response).expect("ping response should convert back"); match internal { JsonRpcMessage::Response(r) => { @@ -180,10 +183,8 @@ mod tests { assert_eq!(id_seen_by_rmcp, serde_json::json!(99)); // Layer 4: rmcp produces a response with the same ID echoed back - let rmcp_tx = ServerJsonRpcMessage::response( - ServerResult::empty(()), - RequestId::Number(99), - ); + let rmcp_tx = + ServerJsonRpcMessage::response(ServerResult::empty(()), RequestId::Number(99)); let response = rmcp_server_tx_to_internal(rmcp_tx).unwrap(); // The response ID must equal the original request ID @@ -193,10 +194,7 @@ mod tests { #[test] fn full_client_roundtrip_response_id_preserved() { // Client side: rmcp produces an outbound request - let rmcp_tx = ClientJsonRpcMessage::response( - ClientResult::empty(()), - RequestId::Number(7), - ); + let rmcp_tx = ClientJsonRpcMessage::response(ClientResult::empty(()), RequestId::Number(7)); let internal = rmcp_client_tx_to_internal(rmcp_tx).unwrap(); assert_eq!(internal.id(), Some(&serde_json::json!(7))); @@ -268,13 +266,19 @@ mod tests { // Key derived from response ID must match the one stored from request ID assert_eq!(key, resp_key); - assert_eq!(request_id_to_event_id.remove(&resp_key), Some(fake_event_id)); + assert_eq!( + request_id_to_event_id.remove(&resp_key), + Some(fake_event_id) + ); } #[test] fn layer5_error_response_correlation_works() { let mut map: HashMap = HashMap::new(); - map.insert(serde_json::to_string(&serde_json::json!(5)).unwrap(), "evt5".to_string()); + map.insert( + serde_json::to_string(&serde_json::json!(5)).unwrap(), + "evt5".to_string(), + ); let error_response = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { jsonrpc: "2.0".to_string(), diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 72f729f..0d52db0 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -12,310 +12,316 @@ use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig} use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; use super::convert::{ - internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, - rmcp_server_tx_to_internal, + internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, + rmcp_server_tx_to_internal, }; /// rmcp server worker wrapper for ContextVM Nostr server transport. pub struct NostrServerWorker { - transport: NostrServerTransport, - // rmcp service instance is single-peer. Keep one active client per worker. - active_client_pubkey: Option, - // Maps request id (serialized JSON value) -> incoming Nostr event id. - request_id_to_event_id: HashMap, + transport: NostrServerTransport, + // rmcp service instance is single-peer. Keep one active client per worker. + active_client_pubkey: Option, + // Maps request id (serialized JSON value) -> incoming Nostr event id. + request_id_to_event_id: HashMap, } impl NostrServerWorker { - /// Create a new server worker from existing server transport config. - pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result - where - T: nostr_sdk::prelude::IntoNostrSigner, - { - let transport = NostrServerTransport::new(signer, config).await?; - Ok(Self { - transport, - active_client_pubkey: None, - request_id_to_event_id: HashMap::new(), - }) - } - - /// Access the wrapped transport. - pub fn transport(&self) -> &NostrServerTransport { - &self.transport - } + /// Create a new server worker from existing server transport config. + pub async fn new(signer: T, config: NostrServerTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrServerTransport::new(signer, config).await?; + Ok(Self { + transport, + active_client_pubkey: None, + request_id_to_event_id: HashMap::new(), + }) + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrServerTransport { + &self.transport + } } impl Worker for NostrServerWorker { - type Error = crate::core::error::Error; - type Role = rmcp::RoleServer; - - fn err_closed() -> Self::Error { - Self::Error::Transport("rmcp worker channel closed".to_string()) - } - - fn err_join(e: tokio::task::JoinError) -> Self::Error { - Self::Error::Other(format!("rmcp worker join error: {e}")) - } - - async fn run( - mut self, - mut context: WorkerContext, - ) -> std::result::Result<(), WorkerQuitReason> { - self.transport - .start() - .await - .map_err(WorkerQuitReason::fatal_context("starting server transport"))?; - - let mut rx = self.transport.take_message_receiver().ok_or_else(|| { - WorkerQuitReason::fatal( - Self::Error::Other("server message receiver already taken".to_string()), - "taking server message receiver", - ) - })?; - - let cancellation_token = context.cancellation_token.clone(); - - let quit_reason = loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - break WorkerQuitReason::Cancelled; - } - incoming = rx.recv() => { - let Some(incoming) = incoming else { - break WorkerQuitReason::TransportClosed; - }; - - let crate::transport::server::IncomingRequest { - message, - client_pubkey, - event_id, - .. - } = incoming; - - match &self.active_client_pubkey { - Some(active) if active != &client_pubkey => { - tracing::warn!( - active_client = %active, - ignored_client = %client_pubkey, - "Ignoring message from second client: rmcp server worker currently supports one active client per worker" - ); - continue; - } - None => { - tracing::info!(client_pubkey = %client_pubkey, "Binding rmcp server worker to first client session"); - self.active_client_pubkey = Some(client_pubkey.clone()); - } - _ => {} - } - - if let JsonRpcMessage::Request(req) = &message { - match serde_json::to_string(&req.id) { - Ok(request_key) => { - self.request_id_to_event_id.insert(request_key, event_id); - } - Err(e) => { - tracing::warn!("Failed to serialize request id for correlation map: {e}"); - } - } - } - - if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { - if let Err(reason) = context.send_to_handler(rmcp_msg).await { - break reason; - } - } else { - tracing::warn!("Failed to convert incoming server-side message to rmcp format"); - } - } - outbound = context.recv_from_handler() => { - let outbound = match outbound { - Ok(outbound) => outbound, - Err(reason) => break reason, - }; - - let result = if let Some(internal_msg) = rmcp_server_tx_to_internal(outbound.message) { - self.forward_server_internal(internal_msg).await - } else { - Err(Self::Error::Validation( - "failed converting rmcp server message to internal JSON-RPC".to_string(), - )) - }; - - let _ = outbound.responder.send(result); - } - } - }; - - if let Err(e) = self.transport.close().await { - tracing::warn!("Failed to close server transport cleanly: {e}"); - } - - Err(quit_reason) - } + type Error = crate::core::error::Error; + type Role = rmcp::RoleServer; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting server transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("server message receiver already taken".to_string()), + "taking server message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + let crate::transport::server::IncomingRequest { + message, + client_pubkey, + event_id, + .. + } = incoming; + + match &self.active_client_pubkey { + Some(active) if active != &client_pubkey => { + tracing::warn!( + active_client = %active, + ignored_client = %client_pubkey, + "Ignoring message from second client: rmcp server worker currently supports one active client per worker" + ); + continue; + } + None => { + tracing::info!(client_pubkey = %client_pubkey, "Binding rmcp server worker to first client session"); + self.active_client_pubkey = Some(client_pubkey.clone()); + } + _ => {} + } + + if let JsonRpcMessage::Request(req) = &message { + match serde_json::to_string(&req.id) { + Ok(request_key) => { + self.request_id_to_event_id.insert(request_key, event_id); + } + Err(e) => { + tracing::warn!("Failed to serialize request id for correlation map: {e}"); + } + } + } + + if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!("Failed to convert incoming server-side message to rmcp format"); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_server_tx_to_internal(outbound.message) { + self.forward_server_internal(internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp server message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!("Failed to close server transport cleanly: {e}"); + } + + Err(quit_reason) + } } /// rmcp client worker wrapper for ContextVM Nostr client transport. pub struct NostrClientWorker { - transport: NostrClientTransport, + transport: NostrClientTransport, } impl NostrClientWorker { - /// Create a new client worker from existing client transport config. - pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result - where - T: nostr_sdk::prelude::IntoNostrSigner, - { - let transport = NostrClientTransport::new(signer, config).await?; - Ok(Self { transport }) - } - - /// Access the wrapped transport. - pub fn transport(&self) -> &NostrClientTransport { - &self.transport - } + /// Create a new client worker from existing client transport config. + pub async fn new(signer: T, config: NostrClientTransportConfig) -> Result + where + T: nostr_sdk::prelude::IntoNostrSigner, + { + let transport = NostrClientTransport::new(signer, config).await?; + Ok(Self { transport }) + } + + /// Access the wrapped transport. + pub fn transport(&self) -> &NostrClientTransport { + &self.transport + } } impl Worker for NostrClientWorker { - type Error = crate::core::error::Error; - type Role = rmcp::RoleClient; - - fn err_closed() -> Self::Error { - Self::Error::Transport("rmcp worker channel closed".to_string()) - } - - fn err_join(e: tokio::task::JoinError) -> Self::Error { - Self::Error::Other(format!("rmcp worker join error: {e}")) - } - - async fn run( - mut self, - mut context: WorkerContext, - ) -> std::result::Result<(), WorkerQuitReason> { - self.transport - .start() - .await - .map_err(WorkerQuitReason::fatal_context("starting client transport"))?; - - let mut rx = self.transport.take_message_receiver().ok_or_else(|| { - WorkerQuitReason::fatal( - Self::Error::Other("client message receiver already taken".to_string()), - "taking client message receiver", - ) - })?; - - let cancellation_token = context.cancellation_token.clone(); - - let quit_reason = loop { - tokio::select! { - _ = cancellation_token.cancelled() => { - break WorkerQuitReason::Cancelled; - } - incoming = rx.recv() => { - let Some(incoming) = incoming else { - break WorkerQuitReason::TransportClosed; - }; - - if let Some(rmcp_msg) = internal_to_rmcp_client_rx(&incoming) { - if let Err(reason) = context.send_to_handler(rmcp_msg).await { - break reason; - } - } else { - tracing::warn!("Failed to convert incoming client-side message to rmcp format"); - } - } - outbound = context.recv_from_handler() => { - let outbound = match outbound { - Ok(outbound) => outbound, - Err(reason) => break reason, - }; - - let result = if let Some(internal_msg) = rmcp_client_tx_to_internal(outbound.message) { - self.transport.send(&internal_msg).await - } else { - Err(Self::Error::Validation( - "failed converting rmcp client message to internal JSON-RPC".to_string(), - )) - }; - - let _ = outbound.responder.send(result); - } - } - }; - - if let Err(e) = self.transport.close().await { - tracing::warn!("Failed to close client transport cleanly: {e}"); - } - - Err(quit_reason) - } + type Error = crate::core::error::Error; + type Role = rmcp::RoleClient; + + fn err_closed() -> Self::Error { + Self::Error::Transport("rmcp worker channel closed".to_string()) + } + + fn err_join(e: tokio::task::JoinError) -> Self::Error { + Self::Error::Other(format!("rmcp worker join error: {e}")) + } + + async fn run( + mut self, + mut context: WorkerContext, + ) -> std::result::Result<(), WorkerQuitReason> { + self.transport + .start() + .await + .map_err(WorkerQuitReason::fatal_context("starting client transport"))?; + + let mut rx = self.transport.take_message_receiver().ok_or_else(|| { + WorkerQuitReason::fatal( + Self::Error::Other("client message receiver already taken".to_string()), + "taking client message receiver", + ) + })?; + + let cancellation_token = context.cancellation_token.clone(); + + let quit_reason = loop { + tokio::select! { + _ = cancellation_token.cancelled() => { + break WorkerQuitReason::Cancelled; + } + incoming = rx.recv() => { + let Some(incoming) = incoming else { + break WorkerQuitReason::TransportClosed; + }; + + if let Some(rmcp_msg) = internal_to_rmcp_client_rx(&incoming) { + if let Err(reason) = context.send_to_handler(rmcp_msg).await { + break reason; + } + } else { + tracing::warn!("Failed to convert incoming client-side message to rmcp format"); + } + } + outbound = context.recv_from_handler() => { + let outbound = match outbound { + Ok(outbound) => outbound, + Err(reason) => break reason, + }; + + let result = if let Some(internal_msg) = rmcp_client_tx_to_internal(outbound.message) { + self.transport.send(&internal_msg).await + } else { + Err(Self::Error::Validation( + "failed converting rmcp client message to internal JSON-RPC".to_string(), + )) + }; + + let _ = outbound.responder.send(result); + } + } + }; + + if let Err(e) = self.transport.close().await { + tracing::warn!("Failed to close client transport cleanly: {e}"); + } + + Err(quit_reason) + } } impl NostrServerWorker { - async fn forward_server_internal(&mut self, message: JsonRpcMessage) -> Result<()> { - match message { - JsonRpcMessage::Response(resp) => { - let request_key = serde_json::to_string(&resp.id).map_err(|e| { - crate::core::error::Error::Validation(format!( - "failed to serialize rmcp response id for correlation lookup: {e}" - )) - })?; - - let event_id = if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { - event_id - } else { - resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( + async fn forward_server_internal(&mut self, message: JsonRpcMessage) -> Result<()> { + match message { + JsonRpcMessage::Response(resp) => { + let request_key = serde_json::to_string(&resp.id).map_err(|e| { + crate::core::error::Error::Validation(format!( + "failed to serialize rmcp response id for correlation lookup: {e}" + )) + })?; + + let event_id = + if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { + event_id + } else { + resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( "rmcp server response id has no known correlation mapping and is not a string event id" .to_string(), ) - })? - }; - - self.transport - .send_response(&event_id, JsonRpcMessage::Response(resp)) - .await - } - JsonRpcMessage::ErrorResponse(resp) => { - let request_key = serde_json::to_string(&resp.id).map_err(|e| { - crate::core::error::Error::Validation(format!( - "failed to serialize rmcp error response id for correlation lookup: {e}" - )) - })?; - - let event_id = if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { - event_id - } else { - resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( + })? + }; + + self.transport + .send_response(&event_id, JsonRpcMessage::Response(resp)) + .await + } + JsonRpcMessage::ErrorResponse(resp) => { + let request_key = serde_json::to_string(&resp.id).map_err(|e| { + crate::core::error::Error::Validation(format!( + "failed to serialize rmcp error response id for correlation lookup: {e}" + )) + })?; + + let event_id = + if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { + event_id + } else { + resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( "rmcp server error response id has no known correlation mapping and is not a string event id" .to_string(), ) - })? - }; - - self.transport - .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) - .await - } - JsonRpcMessage::Notification(notification) => { - let target = self.active_client_pubkey.as_deref().ok_or_else(|| { - crate::core::error::Error::Validation( - "cannot forward rmcp server notification: no active client bound" - .to_string(), - ) - })?; - let message = JsonRpcMessage::Notification(notification); - self.transport.send_notification(target, &message, None).await - } - JsonRpcMessage::Request(request) => { - let target = self.active_client_pubkey.as_deref().ok_or_else(|| { - crate::core::error::Error::Validation( - "cannot forward rmcp server request: no active client bound".to_string(), - ) - })?; - let message = JsonRpcMessage::Request(request); - self.transport.send_notification(target, &message, None).await - } - } - } + })? + }; + + self.transport + .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) + .await + } + JsonRpcMessage::Notification(notification) => { + let target = self.active_client_pubkey.as_deref().ok_or_else(|| { + crate::core::error::Error::Validation( + "cannot forward rmcp server notification: no active client bound" + .to_string(), + ) + })?; + let message = JsonRpcMessage::Notification(notification); + self.transport + .send_notification(target, &message, None) + .await + } + JsonRpcMessage::Request(request) => { + let target = self.active_client_pubkey.as_deref().ok_or_else(|| { + crate::core::error::Error::Validation( + "cannot forward rmcp server request: no active client bound".to_string(), + ) + })?; + let message = JsonRpcMessage::Request(request); + self.transport + .send_notification(target, &message, None) + .await + } + } + } } diff --git a/src/transport/base.rs b/src/transport/base.rs index f425f75..419a4b6 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -127,13 +127,16 @@ impl BaseTransport { if should_encrypt { // Single-layer gift wrap: JSON.stringify(signedEvent) → NIP-44 encrypt // This matches the JS/TS SDK's encryptMessage(JSON.stringify(event), recipient) - let event_json = serde_json::to_string(&event) + let event_json = + serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?; + let signer = self + .relay_pool + .client() + .signer() + .await .map_err(|e| Error::Encryption(e.to_string()))?; - let signer = self.relay_pool.client().signer().await - .map_err(|e| Error::Encryption(e.to_string()))?; - let gift_wrap_event = encryption::gift_wrap_single_layer( - &signer, recipient, &event_json, - ).await?; + let gift_wrap_event = + encryption::gift_wrap_single_layer(&signer, recipient, &event_json).await?; self.relay_pool.publish_event(&gift_wrap_event).await?; tracing::debug!( signed_event_id = %signed_event_id, @@ -192,24 +195,60 @@ mod tests { #[test] fn test_should_encrypt_disabled_mode() { - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, None)); - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, Some(true))); - assert!(!should_encrypt(EncryptionMode::Disabled, CTXVM_MESSAGES_KIND, Some(false))); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + None + )); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + Some(true) + )); + assert!(!should_encrypt( + EncryptionMode::Disabled, + CTXVM_MESSAGES_KIND, + Some(false) + )); } #[test] fn test_should_encrypt_required_mode() { - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, None)); - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, Some(false))); - assert!(should_encrypt(EncryptionMode::Required, CTXVM_MESSAGES_KIND, Some(true))); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + None + )); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + Some(false) + )); + assert!(should_encrypt( + EncryptionMode::Required, + CTXVM_MESSAGES_KIND, + Some(true) + )); } #[test] fn test_should_encrypt_optional_mode() { // Default (None) → true - assert!(should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, None)); - assert!(should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, Some(true))); - assert!(!should_encrypt(EncryptionMode::Optional, CTXVM_MESSAGES_KIND, Some(false))); + assert!(should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + None + )); + assert!(should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + Some(true) + )); + assert!(!should_encrypt( + EncryptionMode::Optional, + CTXVM_MESSAGES_KIND, + Some(false) + )); } #[test] @@ -237,10 +276,9 @@ mod tests { let keys = Keys::generate(); let pubkey = keys.public_key(); // Create a dummy event ID - let event_id = EventId::from_hex( - "0000000000000000000000000000000000000000000000000000000000000001", - ) - .unwrap(); + let event_id = + EventId::from_hex("0000000000000000000000000000000000000000000000000000000000000001") + .unwrap(); let tags = BaseTransport::create_response_tags(&pubkey, &event_id); assert_eq!(tags.len(), 2); diff --git a/src/transport/client.rs b/src/transport/client.rs index fbe4c42..5da3555 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -18,7 +18,6 @@ use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; - /// Configuration for the client transport. pub struct NostrClientTransportConfig { /// Relay URLs to connect to. @@ -131,7 +130,13 @@ impl NostrClientTransport { let tags = BaseTransport::create_recipient_tags(&self.server_pubkey); let event_id = self .base - .send_mcp_message(message, &self.server_pubkey, CTXVM_MESSAGES_KIND, tags, None) + .send_mcp_message( + message, + &self.server_pubkey, + CTXVM_MESSAGES_KIND, + tags, + None, + ) .await?; if matches!(message, JsonRpcMessage::Request(_)) { @@ -183,40 +188,40 @@ impl NostrClientTransport { while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) - || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) - { - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { - Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); - continue; - } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => { - let e_tag = serializers::get_tag_value(&inner.tags, "e"); - (inner.content, inner.pubkey, e_tag) - } - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); - continue; - } + let (actual_event_content, actual_pubkey, e_tag) = if event.kind + == Kind::Custom(GIFT_WRAP_KIND) + || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + { + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match client.signer().await { + Ok(s) => s, + Err(e) => { + tracing::error!("Failed to get signer: {e}"); + continue; + } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { + Ok(decrypted_json) => { + match serde_json::from_str::(&decrypted_json) { + Ok(inner) => { + let e_tag = serializers::get_tag_value(&inner.tags, "e"); + (inner.content, inner.pubkey, e_tag) + } + Err(e) => { + tracing::error!("Failed to parse inner event: {e}"); + continue; } - } - Err(e) => { - tracing::error!("Failed to decrypt gift wrap: {e}"); - continue; } } - } else { - let e_tag = serializers::get_tag_value(&event.tags, "e"); - (event.content.clone(), event.pubkey, e_tag) - }; + Err(e) => { + tracing::error!("Failed to decrypt gift wrap: {e}"); + continue; + } + } + } else { + let e_tag = serializers::get_tag_value(&event.tags, "e"); + (event.content.clone(), event.pubkey, e_tag) + }; // Verify it's from our server if actual_pubkey != server_pubkey { @@ -298,7 +303,10 @@ mod tests { assert!(r.result.get("capabilities").is_some()); assert!(r.result.get("serverInfo").is_some()); let server_info = r.result.get("serverInfo").unwrap(); - assert_eq!(server_info.get("name").unwrap().as_str().unwrap(), "Emulated-Stateless-Server"); + assert_eq!( + server_info.get("name").unwrap().as_str().unwrap(), + "Emulated-Stateless-Server" + ); } } diff --git a/src/transport/server.rs b/src/transport/server.rs index d45e1f9..4ccbe15 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -122,7 +122,16 @@ impl NostrServerTransport { let encryption_mode = self.config.encryption_mode; tokio::spawn(async move { - Self::event_loop(client, sessions, event_to_client, tx, allowed, excluded, encryption_mode).await; + Self::event_loop( + client, + sessions, + event_to_client, + tx, + allowed, + excluded, + encryption_mode, + ) + .await; }); // Spawn session cleanup @@ -159,11 +168,7 @@ impl NostrServerTransport { } /// Send a response back to the client that sent the original request. - pub async fn send_response( - &self, - event_id: &str, - mut response: JsonRpcMessage, - ) -> Result<()> { + pub async fn send_response(&self, event_id: &str, mut response: JsonRpcMessage) -> Result<()> { let event_to_client = self.event_to_client.read().await; let client_pubkey_hex = event_to_client .get(event_id) @@ -188,8 +193,8 @@ impl NostrServerTransport { let is_encrypted = session.is_encrypted; drop(sessions); - let client_pubkey = PublicKey::from_hex(&client_pubkey_hex) - .map_err(|e| Error::Other(e.to_string()))?; + let client_pubkey = + PublicKey::from_hex(&client_pubkey_hex).map_err(|e| Error::Other(e.to_string()))?; let event_id_parsed = EventId::from_hex(event_id).map_err(|e| Error::Other(e.to_string()))?; @@ -236,8 +241,8 @@ impl NostrServerTransport { let is_encrypted = session.is_encrypted; drop(sessions); - let client_pubkey = PublicKey::from_hex(client_pubkey_hex) - .map_err(|e| Error::Other(e.to_string()))?; + let client_pubkey = + PublicKey::from_hex(client_pubkey_hex).map_err(|e| Error::Other(e.to_string()))?; let mut tags = BaseTransport::create_recipient_tags(&client_pubkey); if let Some(eid) = correlated_event_id { @@ -329,8 +334,7 @@ impl NostrServerTransport { )); } - let builder = - EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content).tags(tags); + let builder = EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content).tags(tags); self.base.relay_pool.publish(builder).await } @@ -385,11 +389,10 @@ impl NostrServerTransport { let _pubkey_hex = pubkey.to_hex(); for kind in UNENCRYPTED_KINDS { - let builder = EventBuilder::new(Kind::Custom(5), reason) - .tag(Tag::custom( - TagKind::Custom("k".into()), - vec![kind.to_string()], - )); + let builder = EventBuilder::new(Kind::Custom(5), reason).tag(Tag::custom( + TagKind::Custom("k".into()), + vec![kind.to_string()], + )); self.base.relay_pool.publish(builder).await?; } Ok(()) @@ -481,60 +484,60 @@ impl NostrServerTransport { while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { - let (content, sender_pubkey, event_id, is_encrypted) = - if event.kind == Kind::Custom(GIFT_WRAP_KIND) - || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) - { - if encryption_mode == EncryptionMode::Disabled { - tracing::warn!("Received encrypted message but encryption is disabled"); + let (content, sender_pubkey, event_id, is_encrypted) = if event.kind + == Kind::Custom(GIFT_WRAP_KIND) + || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + { + if encryption_mode == EncryptionMode::Disabled { + tracing::warn!("Received encrypted message but encryption is disabled"); + continue; + } + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match client.signer().await { + Ok(s) => s, + Err(e) => { + tracing::error!("Failed to get signer: {e}"); continue; } - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { - Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); - continue; - } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - // The decrypted content is JSON of the inner signed event. - // Use the INNER event's ID for correlation — the client - // registers the inner event ID in its correlation store. - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => ( - inner.content, - inner.pubkey.to_hex(), - inner.id.to_hex(), - true, - ), - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); - continue; - } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { + Ok(decrypted_json) => { + // The decrypted content is JSON of the inner signed event. + // Use the INNER event's ID for correlation — the client + // registers the inner event ID in its correlation store. + match serde_json::from_str::(&decrypted_json) { + Ok(inner) => ( + inner.content, + inner.pubkey.to_hex(), + inner.id.to_hex(), + true, + ), + Err(e) => { + tracing::error!("Failed to parse inner event: {e}"); + continue; } } - Err(e) => { - tracing::error!("Failed to decrypt: {e}"); - continue; - } } - } else { - if encryption_mode == EncryptionMode::Required { - tracing::warn!( - pubkey = %event.pubkey, - "Received unencrypted message but encryption is required" - ); + Err(e) => { + tracing::error!("Failed to decrypt: {e}"); continue; } - ( - event.content.clone(), - event.pubkey.to_hex(), - event.id.to_hex(), - false, - ) - }; + } + } else { + if encryption_mode == EncryptionMode::Required { + tracing::warn!( + pubkey = %event.pubkey, + "Received unencrypted message but encryption is required" + ); + continue; + } + ( + event.content.clone(), + event.pubkey.to_hex(), + event.id.to_hex(), + false, + ) + }; // Parse MCP message let mcp_msg = match serializers::nostr_event_to_mcp_message(&content) { @@ -688,22 +691,36 @@ mod tests { // Insert a session with an old activity time let mut session = ClientSession::new(false); - session.pending_requests.insert("evt1".to_string(), serde_json::json!(1)); - sessions.write().await.insert("pubkey1".to_string(), session); - event_to_client.write().await.insert("evt1".to_string(), "pubkey1".to_string()); + session + .pending_requests + .insert("evt1".to_string(), serde_json::json!(1)); + sessions + .write() + .await + .insert("pubkey1".to_string(), session); + event_to_client + .write() + .await + .insert("evt1".to_string(), "pubkey1".to_string()); // With a long timeout, nothing should be cleaned let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_secs(300), - ).await; + &sessions, + &event_to_client, + Duration::from_secs(300), + ) + .await; assert_eq!(cleaned, 0); assert_eq!(sessions.read().await.len(), 1); // With zero timeout, it should be cleaned thread::sleep(Duration::from_millis(5)); let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_millis(1), - ).await; + &sessions, + &event_to_client, + Duration::from_millis(1), + ) + .await; assert_eq!(cleaned, 1); assert!(sessions.read().await.is_empty()); assert!(event_to_client.read().await.is_empty()); @@ -718,8 +735,11 @@ mod tests { sessions.write().await.insert("active".to_string(), session); let cleaned = NostrServerTransport::cleanup_sessions( - &sessions, &event_to_client, Duration::from_secs(300), - ).await; + &sessions, + &event_to_client, + Duration::from_secs(300), + ) + .await; assert_eq!(cleaned, 0); assert_eq!(sessions.read().await.len(), 1); } @@ -729,24 +749,44 @@ mod tests { #[test] fn test_pending_request_tracking() { let mut session = ClientSession::new(false); - session.pending_requests.insert("event_abc".to_string(), serde_json::json!(42)); - assert_eq!(session.pending_requests.get("event_abc"), Some(&serde_json::json!(42))); + session + .pending_requests + .insert("event_abc".to_string(), serde_json::json!(42)); + assert_eq!( + session.pending_requests.get("event_abc"), + Some(&serde_json::json!(42)) + ); } #[test] fn test_progress_token_tracking() { let mut session = ClientSession::new(false); - session.event_to_progress_token.insert("evt1".to_string(), "token1".to_string()); - session.pending_requests.insert("token1".to_string(), serde_json::json!("evt1")); - assert_eq!(session.event_to_progress_token.get("evt1"), Some(&"token1".to_string())); + session + .event_to_progress_token + .insert("evt1".to_string(), "token1".to_string()); + session + .pending_requests + .insert("token1".to_string(), serde_json::json!("evt1")); + assert_eq!( + session.event_to_progress_token.get("evt1"), + Some(&"token1".to_string()) + ); } // ── Authorization (is_capability_excluded) ────────────────── #[test] fn test_initialize_always_excluded() { - assert!(NostrServerTransport::is_capability_excluded(&[], "initialize", None)); - assert!(NostrServerTransport::is_capability_excluded(&[], "notifications/initialized", None)); + assert!(NostrServerTransport::is_capability_excluded( + &[], + "initialize", + None + )); + assert!(NostrServerTransport::is_capability_excluded( + &[], + "notifications/initialized", + None + )); } #[test] @@ -755,8 +795,16 @@ mod tests { method: "tools/list".to_string(), name: None, }]; - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/list", None)); - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/list", Some("anything"))); + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/list", + None + )); + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/list", + Some("anything") + )); } #[test] @@ -765,9 +813,21 @@ mod tests { method: "tools/call".to_string(), name: Some("get_weather".to_string()), }]; - assert!(NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", Some("get_weather"))); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", Some("other_tool"))); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", None)); + assert!(NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + Some("get_weather") + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + Some("other_tool") + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + None + )); } #[test] @@ -776,14 +836,30 @@ mod tests { method: "tools/list".to_string(), name: None, }]; - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "tools/call", None)); - assert!(!NostrServerTransport::is_capability_excluded(&exclusions, "resources/list", None)); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "tools/call", + None + )); + assert!(!NostrServerTransport::is_capability_excluded( + &exclusions, + "resources/list", + None + )); } #[test] fn test_empty_exclusions_non_init_method() { - assert!(!NostrServerTransport::is_capability_excluded(&[], "tools/list", None)); - assert!(!NostrServerTransport::is_capability_excluded(&[], "tools/call", Some("x"))); + assert!(!NostrServerTransport::is_capability_excluded( + &[], + "tools/list", + None + )); + assert!(!NostrServerTransport::is_capability_excluded( + &[], + "tools/call", + Some("x") + )); } // ── Encryption mode enforcement ───────────────────────────── From f4e2fd976f21a726af9f6b394e166007ba1d3b78 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Sat, 4 Apr 2026 19:50:14 +0530 Subject: [PATCH 12/69] fix: removed hard coded values and redundant KIND types --- examples/rmcp_integration_test.rs | 10 +-- src/core/constants.rs | 17 ++++- src/encryption/mod.rs | 13 +++- src/lib.rs | 1 + src/rmcp_transport/convert.rs | 117 ++++++++++++++++++++++++++++-- src/rmcp_transport/worker.rs | 1 + src/transport/client.rs | 4 +- 7 files changed, 143 insertions(+), 20 deletions(-) diff --git a/examples/rmcp_integration_test.rs b/examples/rmcp_integration_test.rs index c9656ef..c0bedc6 100644 --- a/examples/rmcp_integration_test.rs +++ b/examples/rmcp_integration_test.rs @@ -17,7 +17,7 @@ //! cargo run --example rmcp_integration_test --features rmcp -- all wss://relay.primal.net use anyhow::{anyhow, bail, Context, Result}; -use contextvm_sdk::core::constants::MCP_PROTOCOL_VERSION; +use contextvm_sdk::core::constants::mcp_protocol_version; use contextvm_sdk::core::types::{ EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, ServerInfo as CtxServerInfo, @@ -359,7 +359,7 @@ async fn run_hybrid_relay_case(relay_url: &str) -> Result<()> { id: init_id.clone(), method: "initialize".to_string(), params: Some(serde_json::json!({ - "protocolVersion": MCP_PROTOCOL_VERSION, + "protocolVersion": mcp_protocol_version(), "capabilities": { "tools": {}, "resources": {} @@ -654,7 +654,7 @@ fn assert_initialize_shape(response: &JsonRpcMessage) -> Result<()> { let JsonRpcMessage::Response(resp) = response else { bail!("expected initialize response, got {response:?}"); }; - + let expected_protocol = mcp_protocol_version(); let protocol = resp .result .get("protocolVersion") @@ -663,7 +663,7 @@ fn assert_initialize_shape(response: &JsonRpcMessage) -> Result<()> { if !is_supported_protocol(protocol) { bail!( - "unexpected protocolVersion in initialize response: expected one of [{MCP_PROTOCOL_VERSION}, {}], got {protocol}", + "unexpected protocolVersion in initialize response: expected one of [{expected_protocol}, {}], got {protocol}", ProtocolVersion::LATEST ); } @@ -676,7 +676,7 @@ fn assert_initialize_shape(response: &JsonRpcMessage) -> Result<()> { } fn is_supported_protocol(protocol: &str) -> bool { - protocol == MCP_PROTOCOL_VERSION || protocol == ProtocolVersion::LATEST.to_string() + protocol == mcp_protocol_version() || protocol == ProtocolVersion::LATEST.to_string() } fn assert_error_response(response: &JsonRpcMessage) -> Result<()> { diff --git a/src/core/constants.rs b/src/core/constants.rs index b7f8a8b..85cf82b 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -33,6 +33,7 @@ pub const RESOURCETEMPLATES_LIST_KIND: u16 = 11319; /// Prompts list (addressable, kind 11320) pub const PROMPTS_LIST_KIND: u16 = 11320; +pub const KIND_GIFT_WRAP: u16 = 1059; /// Nostr tag constants pub mod tags { /// Public key tag @@ -73,7 +74,6 @@ pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; /// /// Matches the `protocolVersion` field of the `InitializeResult` JSON-RPC response. /// Keep this in sync with the MCP spec and rmcp's `ProtocolVersion::LATEST`. -pub const MCP_PROTOCOL_VERSION: &str = "2025-11-25"; /// Default LRU cache size for deduplication pub const DEFAULT_LRU_SIZE: usize = 5000; @@ -109,6 +109,21 @@ pub const UNENCRYPTED_KINDS: &[u16] = &[ PROMPTS_LIST_KIND, ]; + +#[cfg(feature = "rmcp")] +pub fn mcp_protocol_version() -> &'static str { + use std::sync::OnceLock; + static VERSION: OnceLock = OnceLock::new(); + VERSION + .get_or_init(|| rmcp::model::ProtocolVersion::LATEST.to_string()) + .as_str() +} + +#[cfg(not(feature = "rmcp"))] +pub const fn mcp_protocol_version() -> &'static str { + "2025-11-25" +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index d5171e0..0d5b369 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -111,6 +111,8 @@ pub async fn gift_wrap( #[cfg(test)] mod tests { + use crate::core::constants::KIND_GIFT_WRAP; + use super::*; #[tokio::test] @@ -139,7 +141,10 @@ mod tests { /// 2. NIP-44 encrypt the plaintext using ephemeral_secret + recipient_pubkey /// 3. Build kind 1059 event with encrypted content, `p` tag = recipient /// 4. Sign with ephemeral key - async fn create_js_style_gift_wrap(plaintext: &str, recipient: &PublicKey) -> (Event, Keys) { + async fn create_simple_gift_wrap( + plaintext: &str, + recipient: &PublicKey, + ) -> (Event, Keys) { let ephemeral = Keys::generate(); // Single-layer NIP-44 encrypt @@ -148,8 +153,8 @@ mod tests { .unwrap(); // Build kind 1059 event - let builder = - EventBuilder::new(Kind::Custom(1059), encrypted).tag(Tag::public_key(*recipient)); + let builder = EventBuilder::new(Kind::from(KIND_GIFT_WRAP), encrypted) + .tag(Tag::public_key(*recipient)); let event = builder.sign_with_keys(&ephemeral).unwrap(); (event, ephemeral) @@ -177,7 +182,7 @@ mod tests { // Step 3: Encrypt as a gift wrap let (gift_wrap, _ephemeral) = - create_js_style_gift_wrap(&inner_json, &server_keys.public_key()).await; + create_simple_gift_wrap(&inner_json, &server_keys.public_key()).await; assert_eq!(gift_wrap.kind, Kind::Custom(1059)); diff --git a/src/lib.rs b/src/lib.rs index cc2a09e..0af4d58 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,7 @@ pub mod proxy; pub mod relay; pub mod signer; pub mod transport; +pub mod util; #[cfg(feature = "rmcp")] pub mod rmcp_transport; diff --git a/src/rmcp_transport/convert.rs b/src/rmcp_transport/convert.rs index e575c37..a20fe20 100644 --- a/src/rmcp_transport/convert.rs +++ b/src/rmcp_transport/convert.rs @@ -4,6 +4,7 @@ //! compatibility and avoid fragile hand-mapping between evolving type systems. use crate::core::types::JsonRpcMessage; +use crate::util::logger; /// Convert internal JSON-RPC message into rmcp server RX message. /// @@ -12,8 +13,33 @@ use crate::core::types::JsonRpcMessage; pub fn internal_to_rmcp_server_rx( msg: &JsonRpcMessage, ) -> Option> { - let value = serde_json::to_value(msg).ok()?; - serde_json::from_value(value).ok() + let direction = "internal_to_rmcp_server_rx"; + let target = "contextvm_sdk::rmcp_transport::convert"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to serialize message into intermediate JSON: {error}" + ), + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" + ), + ); + None + } + } } /// Convert internal JSON-RPC message into rmcp client RX message. @@ -23,24 +49,99 @@ pub fn internal_to_rmcp_server_rx( pub fn internal_to_rmcp_client_rx( msg: &JsonRpcMessage, ) -> Option> { - let value = serde_json::to_value(msg).ok()?; - serde_json::from_value(value).ok() + let direction = "internal_to_rmcp_client_rx"; + let target = "contextvm_sdk::rmcp_transport::convert"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to serialize message into intermediate JSON: {error}" + ), + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" + ), + ); + None + } + } } /// Convert rmcp server TX message back into internal JSON-RPC. pub fn rmcp_server_tx_to_internal( msg: rmcp::service::TxJsonRpcMessage, ) -> Option { - let value = serde_json::to_value(msg).ok()?; - serde_json::from_value(value).ok() + let direction = "rmcp_server_tx_to_internal"; + let target = "contextvm_sdk::rmcp_transport::convert"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to serialize message into intermediate JSON: {error}" + ), + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" + ), + ); + None + } + } } /// Convert rmcp client TX message back into internal JSON-RPC. pub fn rmcp_client_tx_to_internal( msg: rmcp::service::TxJsonRpcMessage, ) -> Option { - let value = serde_json::to_value(msg).ok()?; - serde_json::from_value(value).ok() + let direction = "rmcp_client_tx_to_internal"; + let target = "contextvm_sdk::rmcp_transport::convert"; + let value = match serde_json::to_value(msg) { + Ok(value) => value, + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to serialize message into intermediate JSON: {error}" + ), + ); + return None; + } + }; + + match serde_json::from_value(value.clone()) { + Ok(parsed) => Some(parsed), + Err(error) => { + logger::error_with_target( + target, + format!( + "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" + ), + ); + None + } + } } #[cfg(all(test, feature = "rmcp"))] diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 0d52db0..57b19db 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -7,6 +7,7 @@ use std::collections::HashMap; use crate::core::error::Result; use crate::core::types::JsonRpcMessage; +use crate::util::logger; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig}; use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; diff --git a/src/transport/client.rs b/src/transport/client.rs index 5da3555..1cc1a0e 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -161,7 +161,7 @@ impl NostrClientTransport { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": crate::core::constants::MCP_PROTOCOL_VERSION, + "protocolVersion": crate::core::constants::mcp_protocol_version(), "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" @@ -284,7 +284,7 @@ mod tests { jsonrpc: "2.0".to_string(), id: request_id.clone(), result: serde_json::json!({ - "protocolVersion": crate::core::constants::MCP_PROTOCOL_VERSION, + "protocolVersion": crate::core::constants::mcp_protocol_version(), "serverInfo": { "name": "Emulated-Stateless-Server", "version": "1.0.0" From 905b16ab90923e87860ef8e4cd28fcb4f3db50cc Mon Sep 17 00:00:00 2001 From: Harsh Date: Sun, 5 Apr 2026 23:42:06 +0530 Subject: [PATCH 13/69] fix: add missing util::logger module for rmcp build --- src/rmcp_transport/worker.rs | 1 - src/util/logger.rs | 14 ++++++++++++++ src/util/mod.rs | 3 +++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 src/util/logger.rs create mode 100644 src/util/mod.rs diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 57b19db..0d52db0 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -7,7 +7,6 @@ use std::collections::HashMap; use crate::core::error::Result; use crate::core::types::JsonRpcMessage; -use crate::util::logger; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig}; use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; diff --git a/src/util/logger.rs b/src/util/logger.rs new file mode 100644 index 0000000..fa353f4 --- /dev/null +++ b/src/util/logger.rs @@ -0,0 +1,14 @@ +//! Thin wrappers around `tracing` for stable call sites (e.g. rmcp conversion). +//! +//! Note: `tracing` macros require a string-literal `target:` for callsite registration. +//! The `logical_target` field preserves the module-style string passed by callers. + +/// Log at error level; `logical_target` is the caller’s notion of module/path (for filters in output). +pub fn error_with_target(logical_target: &str, message: impl std::fmt::Display) { + tracing::error!( + target: "contextvm_sdk", + logical_target = logical_target, + "{}", + message + ); +} diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..5f82fa5 --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,3 @@ +//! Small shared helpers (logging wrappers, etc.). + +pub mod logger; From ae57431124e701a1503d64ac3cd2821d38e28f0e Mon Sep 17 00:00:00 2001 From: Harsh Date: Mon, 6 Apr 2026 14:51:17 +0530 Subject: [PATCH 14/69] test: add comprehensive tests for core types --- src/core/types.rs | 275 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 275 insertions(+) diff --git a/src/core/types.rs b/src/core/types.rs index 4ea2e34..a811269 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -222,3 +222,278 @@ pub struct CapabilityExclusion { /// Optional capability name for method-specific exclusions (e.g., "get_weather"). pub name: Option, } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::thread; + use std::time::Duration; + + #[test] + fn test_encryption_mode_serde_roundtrip_optional() { + let mode = EncryptionMode::Optional; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"optional\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_encryption_mode_serde_roundtrip_required() { + let mode = EncryptionMode::Required; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"required\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_encryption_mode_serde_roundtrip_disabled() { + let mode = EncryptionMode::Disabled; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"disabled\""); + let parsed: EncryptionMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + fn assert_json_rpc_roundtrip(msg: &JsonRpcMessage) { + let wire = serde_json::to_string(msg).unwrap(); + let parsed: JsonRpcMessage = serde_json::from_str(&wire).unwrap(); + let before = serde_json::to_value(msg).unwrap(); + let after = serde_json::to_value(&parsed).unwrap(); + assert_eq!(before, after); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_request() { + let msg = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(42), + method: "tools/list".to_string(), + params: Some(json!({ "cursor": null })), + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_request_without_params() { + let msg = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!("req-id"), + method: "ping".to_string(), + params: None, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_response() { + let msg = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + result: json!({ "tools": [] }), + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_error_response() { + let msg = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(99), + error: JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: Some(json!({ "hint": "fix it" })), + }, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_serde_roundtrip_notification() { + let msg = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + assert_json_rpc_roundtrip(&msg); + } + + #[test] + fn test_json_rpc_message_type_predicates() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(1), + method: "m".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(1), + error: JsonRpcError { + code: -1, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "n".to_string(), + params: None, + }); + + assert!(req.is_request()); + assert!(res.is_response()); + assert!(err.is_error()); + assert!(notif.is_notification()); + } + + #[test] + fn test_json_rpc_error_data_none_omitted() { + let err = JsonRpcError { + code: -32600, + message: "bad".to_string(), + data: None, + }; + let json_str = serde_json::to_string(&err).unwrap(); + let value: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + let obj = value.as_object().expect("error object"); + assert!( + !obj.contains_key("data"), + "expected data omitted when None, got: {json_str}" + ); + } + + #[test] + fn test_json_rpc_message_method() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!(0), + method: "tools/call".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(0), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!(0), + error: JsonRpcError { + code: 0, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: None, + }); + + assert_eq!(req.method(), Some("tools/call")); + assert_eq!(res.method(), None); + assert_eq!(err.method(), None); + assert_eq!(notif.method(), Some("notifications/progress")); + } + + #[test] + fn test_json_rpc_message_id() { + let req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: json!("abc"), + method: "m".to_string(), + params: None, + }); + let res = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: json!(7), + result: json!(null), + }); + let err = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: json!([1, 2]), + error: JsonRpcError { + code: 0, + message: "e".to_string(), + data: None, + }, + }); + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "n".to_string(), + params: None, + }); + + assert_eq!(req.id(), Some(&json!("abc"))); + assert_eq!(res.id(), Some(&json!(7))); + assert_eq!(err.id(), Some(&json!([1, 2]))); + assert_eq!(notif.id(), None); + } + + #[test] + fn test_server_info_serde_all_fields_present() { + let info = ServerInfo { + name: Some("Test Server".to_string()), + version: Some("1.0.0".to_string()), + picture: Some("https://example.com/p.png".to_string()), + website: Some("https://example.com".to_string()), + about: Some("About text".to_string()), + }; + let json_str = serde_json::to_string(&info).unwrap(); + let parsed: ServerInfo = serde_json::from_str(&json_str).unwrap(); + assert_eq!(parsed.name, info.name); + assert_eq!(parsed.version, info.version); + assert_eq!(parsed.picture, info.picture); + assert_eq!(parsed.website, info.website); + assert_eq!(parsed.about, info.about); + } + + #[test] + fn test_server_info_serde_optional_fields_omitted() { + let info = ServerInfo { + name: None, + version: None, + picture: None, + website: None, + about: None, + }; + let json_str = serde_json::to_string(&info).unwrap(); + assert_eq!(json_str, "{}"); + } + + #[test] + fn test_client_session_new_initial_state_encrypted() { + let session = ClientSession::new(true); + assert!(!session.is_initialized); + assert!(session.is_encrypted); + assert!(session.pending_requests.is_empty()); + assert!(session.event_to_progress_token.is_empty()); + } + + #[test] + fn test_client_session_new_initial_state_plaintext() { + let session = ClientSession::new(false); + assert!(!session.is_initialized); + assert!(!session.is_encrypted); + assert!(session.pending_requests.is_empty()); + assert!(session.event_to_progress_token.is_empty()); + } + + #[test] + fn test_client_session_update_activity() { + let mut session = ClientSession::new(false); + let first = session.last_activity; + thread::sleep(Duration::from_millis(10)); + session.update_activity(); + assert!(session.last_activity > first); + } +} From 024ee46c60e9a77b4a4fc5d9727f5380b61b3896 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 03:01:49 +0530 Subject: [PATCH 15/69] feat: implemented tracing setup for subscriptions and custom log format and file paths --- src/util/logger.rs | 14 --- src/util/mod.rs | 4 +- src/util/tracing_setup.rs | 259 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+), 17 deletions(-) delete mode 100644 src/util/logger.rs create mode 100644 src/util/tracing_setup.rs diff --git a/src/util/logger.rs b/src/util/logger.rs deleted file mode 100644 index fa353f4..0000000 --- a/src/util/logger.rs +++ /dev/null @@ -1,14 +0,0 @@ -//! Thin wrappers around `tracing` for stable call sites (e.g. rmcp conversion). -//! -//! Note: `tracing` macros require a string-literal `target:` for callsite registration. -//! The `logical_target` field preserves the module-style string passed by callers. - -/// Log at error level; `logical_target` is the caller’s notion of module/path (for filters in output). -pub fn error_with_target(logical_target: &str, message: impl std::fmt::Display) { - tracing::error!( - target: "contextvm_sdk", - logical_target = logical_target, - "{}", - message - ); -} diff --git a/src/util/mod.rs b/src/util/mod.rs index 5f82fa5..c5eb5d2 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,3 +1 @@ -//! Small shared helpers (logging wrappers, etc.). - -pub mod logger; +pub mod tracing_setup; diff --git a/src/util/tracing_setup.rs b/src/util/tracing_setup.rs new file mode 100644 index 0000000..313bb5d --- /dev/null +++ b/src/util/tracing_setup.rs @@ -0,0 +1,259 @@ +//! Internal tracing subscriber setup for ContextVM transports. + +use std::fmt; +use std::fs::{File, OpenOptions}; +use std::io::{self, Write}; +use std::path::{Path, PathBuf}; +use std::sync::{Mutex, OnceLock}; + +use tracing::Event; +use tracing_subscriber::fmt::format::Writer; +use tracing_subscriber::fmt::writer::MakeWriter; +use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields}; +use tracing_subscriber::layer::{Layer, SubscriberExt}; +use tracing_subscriber::registry::LookupSpan; +use tracing_subscriber::{EnvFilter, Registry}; + +use crate::core::error::{Error, Result}; + +static TRACING_SETUP_LOCK: OnceLock> = OnceLock::new(); +static TRACING_INITIALIZED: OnceLock<()> = OnceLock::new(); +static LOG_DESTINATION: OnceLock> = OnceLock::new(); + +fn tracing_setup_lock() -> &'static Mutex<()> { + TRACING_SETUP_LOCK.get_or_init(|| Mutex::new(())) +} + +fn log_destination() -> &'static Mutex { + LOG_DESTINATION.get_or_init(|| Mutex::new(LogDestination::default())) +} + +pub(crate) fn init_tracer(log_file_path: Option<&str>) -> Result<()> { + let _guard = tracing_setup_lock() + .lock() + .map_err(|_| Error::Other("failed to acquire tracing setup lock".to_string()))?; + + configure_file_output(log_file_path)?; + + if TRACING_INITIALIZED.get().is_some() { + return Ok(()); + } + + let subscriber = Registry::default().with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_writer(ContextVmMakeWriter) + .event_format(ContextVmEventFormatter) + .with_filter(build_env_filter()), + ); + + match tracing::subscriber::set_global_default(subscriber) { + Ok(()) => { + let _ = TRACING_INITIALIZED.set(()); + Ok(()) + } + Err(error) => { + let text = error.to_string(); + if text.contains("global default trace dispatcher has already been set") { + let _ = TRACING_INITIALIZED.set(()); + Ok(()) + } else { + Err(Error::Other(format!( + "failed to initialize tracing subscriber: {text}" + ))) + } + } + } +} + +fn configure_file_output(log_file_path: Option<&str>) -> Result<()> { + let Some(path) = normalize_log_file_path(log_file_path) else { + return Ok(()); + }; + + ensure_parent_exists(&path)?; + + let file = OpenOptions::new() + .create(true) + .append(true) + .open(&path) + .map_err(|error| { + Error::Other(format!( + "failed to open log file {}: {error}", + path.display() + )) + })?; + + let mut destination = log_destination() + .lock() + .map_err(|_| Error::Other("failed to acquire log destination lock".to_string()))?; + destination.file = Some(file); + + Ok(()) +} + +fn normalize_log_file_path(log_file_path: Option<&str>) -> Option { + let trimmed = log_file_path?.trim(); + if trimmed.is_empty() { + None + } else { + Some(PathBuf::from(trimmed)) + } +} + +fn ensure_parent_exists(path: &Path) -> Result<()> { + if let Some(parent) = path.parent() { + if !parent.as_os_str().is_empty() { + std::fs::create_dir_all(parent).map_err(|error| { + Error::Other(format!( + "failed to create log directory {}: {error}", + parent.display() + )) + })?; + } + } + + Ok(()) +} + +fn build_env_filter() -> EnvFilter { + EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("contextvm_sdk=info,rmcp=warn")) +} + +#[derive(Default)] +struct LogDestination { + file: Option, +} + +#[derive(Clone, Copy)] +struct ContextVmMakeWriter; + +impl<'a> MakeWriter<'a> for ContextVmMakeWriter { + type Writer = ContextVmWriter; + + fn make_writer(&'a self) -> Self::Writer { + ContextVmWriter { + stdout: io::stdout(), + } + } +} + +struct ContextVmWriter { + stdout: io::Stdout, +} + +impl Write for ContextVmWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.stdout.write_all(buf)?; + + if let Ok(mut destination) = log_destination().lock() { + if let Some(file) = destination.file.as_mut() { + let _ = file.write_all(buf); + } + } + + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + self.stdout.flush()?; + + if let Ok(mut destination) = log_destination().lock() { + if let Some(file) = destination.file.as_mut() { + let _ = file.flush(); + } + } + + Ok(()) + } +} + +#[derive(Default)] +struct MessageVisitor { + message: Option, + extra_fields: Vec<(String, String)>, +} + +impl MessageVisitor { + fn record_field(&mut self, name: &str, value: String) { + if name == "message" { + self.message = Some(value); + } else { + self.extra_fields.push((name.to_string(), value)); + } + } +} + +impl tracing::field::Visit for MessageVisitor { + fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { + self.record_field(field.name(), value.to_string()); + } + + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + self.record_field(field.name(), value.to_string()); + } + + fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { + self.record_field(field.name(), value.to_string()); + } + + fn record_str(&mut self, field: &tracing::field::Field, value: &str) { + self.record_field(field.name(), value.to_string()); + } + + fn record_error( + &mut self, + field: &tracing::field::Field, + value: &(dyn std::error::Error + 'static), + ) { + self.record_field(field.name(), value.to_string()); + } + + fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn fmt::Debug) { + self.record_field(field.name(), format!("{value:?}")); + } +} + +struct ContextVmEventFormatter; + +impl FormatEvent for ContextVmEventFormatter +where + S: tracing::Subscriber + for<'span> LookupSpan<'span>, + N: for<'writer> FormatFields<'writer> + 'static, +{ + fn format_event( + &self, + _ctx: &FmtContext<'_, S, N>, + mut writer: Writer<'_>, + event: &Event<'_>, + ) -> fmt::Result { + let mut visitor = MessageVisitor::default(); + event.record(&mut visitor); + + let metadata = event.metadata(); + let timestamp = unix_timestamp(); + let level = metadata.level().to_string().to_lowercase(); + let message = visitor.message.unwrap_or_default(); + + write!( + writer, + "{timestamp}:{level}::{}:{message}", + metadata.target() + )?; + + for (key, value) in visitor.extra_fields { + write!(writer, " {key}={value}")?; + } + + writeln!(writer) + } +} + +fn unix_timestamp() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + + let now = SystemTime::now(); + let duration = now.duration_since(UNIX_EPOCH).unwrap_or_default(); + format!("{}.{:03}", duration.as_secs(), duration.subsec_millis()) +} From c84036d31ea6b58f1916a75b2b9b36c0ea3a152f Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 03:02:56 +0530 Subject: [PATCH 16/69] enhancement: improved error logging in encryption module --- src/encryption/mod.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index 0d5b369..1e32426 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -141,10 +141,7 @@ mod tests { /// 2. NIP-44 encrypt the plaintext using ephemeral_secret + recipient_pubkey /// 3. Build kind 1059 event with encrypted content, `p` tag = recipient /// 4. Sign with ephemeral key - async fn create_simple_gift_wrap( - plaintext: &str, - recipient: &PublicKey, - ) -> (Event, Keys) { + async fn create_simple_gift_wrap(plaintext: &str, recipient: &PublicKey) -> (Event, Keys) { let ephemeral = Keys::generate(); // Single-layer NIP-44 encrypt From bc10f030d1a89555554ef5e775bfaaccbc9339fb Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 03:03:49 +0530 Subject: [PATCH 17/69] enhancement: improved error logging in gateway and proxy modules --- src/gateway/mod.rs | 1 + src/proxy/mod.rs | 1 + 2 files changed, 2 insertions(+) diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 427b48f..ea388b7 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -127,6 +127,7 @@ mod tests { excluded_capabilities: vec![], cleanup_interval: Duration::from_secs(120), session_timeout: Duration::from_secs(600), + log_file_path: None, }; let config = GatewayConfig { nostr_config }; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 25fc2a8..df1d386 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -112,6 +112,7 @@ mod tests { encryption_mode: EncryptionMode::Required, is_stateless: true, timeout: Duration::from_secs(60), + log_file_path: None, }; let config = ProxyConfig { nostr_config }; From 3cb5eb0ad873d17b0db66aa28640fe8a13c52d1e Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 03:05:29 +0530 Subject: [PATCH 18/69] enhancement: improved error logging in rmcp and transport modules --- src/lib.rs | 2 +- src/rmcp_transport/convert.rs | 91 +++++++-------- src/transport/client.rs | 143 ++++++++++++++++++++---- src/transport/server.rs | 204 ++++++++++++++++++++++++++++------ 4 files changed, 344 insertions(+), 96 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0af4d58..2157224 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,10 +44,10 @@ pub mod proxy; pub mod relay; pub mod signer; pub mod transport; -pub mod util; #[cfg(feature = "rmcp")] pub mod rmcp_transport; +mod util; // Re-export commonly used types pub use core::error::{Error, Result}; diff --git a/src/rmcp_transport/convert.rs b/src/rmcp_transport/convert.rs index a20fe20..7df0783 100644 --- a/src/rmcp_transport/convert.rs +++ b/src/rmcp_transport/convert.rs @@ -4,7 +4,8 @@ //! compatibility and avoid fragile hand-mapping between evolving type systems. use crate::core::types::JsonRpcMessage; -use crate::util::logger; + +const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::convert"; /// Convert internal JSON-RPC message into rmcp server RX message. /// @@ -14,15 +15,14 @@ pub fn internal_to_rmcp_server_rx( msg: &JsonRpcMessage, ) -> Option> { let direction = "internal_to_rmcp_server_rx"; - let target = "contextvm_sdk::rmcp_transport::convert"; let value = match serde_json::to_value(msg) { Ok(value) => value, Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to serialize message into intermediate JSON: {error}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" ); return None; } @@ -31,11 +31,12 @@ pub fn internal_to_rmcp_server_rx( match serde_json::from_value(value.clone()) { Ok(parsed) => Some(parsed), Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" ); None } @@ -50,15 +51,14 @@ pub fn internal_to_rmcp_client_rx( msg: &JsonRpcMessage, ) -> Option> { let direction = "internal_to_rmcp_client_rx"; - let target = "contextvm_sdk::rmcp_transport::convert"; let value = match serde_json::to_value(msg) { Ok(value) => value, Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to serialize message into intermediate JSON: {error}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" ); return None; } @@ -67,11 +67,12 @@ pub fn internal_to_rmcp_client_rx( match serde_json::from_value(value.clone()) { Ok(parsed) => Some(parsed), Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" ); None } @@ -83,15 +84,14 @@ pub fn rmcp_server_tx_to_internal( msg: rmcp::service::TxJsonRpcMessage, ) -> Option { let direction = "rmcp_server_tx_to_internal"; - let target = "contextvm_sdk::rmcp_transport::convert"; let value = match serde_json::to_value(msg) { Ok(value) => value, Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to serialize message into intermediate JSON: {error}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" ); return None; } @@ -100,11 +100,12 @@ pub fn rmcp_server_tx_to_internal( match serde_json::from_value(value.clone()) { Ok(parsed) => Some(parsed), Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" ); None } @@ -116,15 +117,14 @@ pub fn rmcp_client_tx_to_internal( msg: rmcp::service::TxJsonRpcMessage, ) -> Option { let direction = "rmcp_client_tx_to_internal"; - let target = "contextvm_sdk::rmcp_transport::convert"; let value = match serde_json::to_value(msg) { Ok(value) => value, Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to serialize message into intermediate JSON: {error}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + "Failed to serialize message into intermediate JSON" ); return None; } @@ -133,11 +133,12 @@ pub fn rmcp_client_tx_to_internal( match serde_json::from_value(value.clone()) { Ok(parsed) => Some(parsed), Err(error) => { - logger::error_with_target( - target, - format!( - "{direction}: failed to parse converted JSON payload: {error}; payload={value:?}" - ), + tracing::error!( + target: LOG_TARGET, + direction = direction, + error = %error, + payload = ?value, + "Failed to parse converted JSON payload" ); None } diff --git a/src/transport/client.rs b/src/transport/client.rs index 1cc1a0e..e41e518 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -18,6 +18,10 @@ use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; +use crate::util::tracing_setup; + +const LOG_TARGET: &str = "contextvm_sdk::transport::client"; + /// Configuration for the client transport. pub struct NostrClientTransportConfig { /// Relay URLs to connect to. @@ -30,6 +34,8 @@ pub struct NostrClientTransportConfig { pub is_stateless: bool, /// Response timeout (default: 30s). pub timeout: Duration, + /// Optional log file path. Logs always go to stdout and are also appended here when set. + pub log_file_path: Option, } impl Default for NostrClientTransportConfig { @@ -40,6 +46,7 @@ impl Default for NostrClientTransportConfig { encryption_mode: EncryptionMode::Optional, is_stateless: false, timeout: Duration::from_secs(30), + log_file_path: None, } } } @@ -62,12 +69,35 @@ impl NostrClientTransport { where T: IntoNostrSigner, { - let server_pubkey = PublicKey::from_hex(&config.server_pubkey) - .map_err(|e| Error::Other(format!("Invalid server pubkey: {e}")))?; - - let relay_pool = Arc::new(RelayPool::new(signer).await?); + tracing_setup::init_tracer(config.log_file_path.as_deref())?; + + let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %config.server_pubkey, + "Invalid server pubkey" + ); + Error::Other(format!("Invalid server pubkey: {error}")) + })?; + + let relay_pool = Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for client transport" + ); + error + })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + stateless = config.is_stateless, + encryption_mode = ?config.encryption_mode, + "Created client transport" + ); Ok(Self { base: BaseTransport { relay_pool, @@ -84,12 +114,44 @@ impl NostrClientTransport { /// Connect and start listening for responses. pub async fn start(&mut self) -> Result<()> { - self.base.connect(&self.config.relay_urls).await?; - - let pubkey = self.base.get_public_key().await?; - tracing::info!(pubkey = %pubkey.to_hex(), "Client transport started"); - - self.base.subscribe_for_pubkey(&pubkey).await?; + self.base + .connect(&self.config.relay_urls) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to connect client transport to relays" + ); + error + })?; + + let pubkey = self.base.get_public_key().await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to fetch client transport public key" + ); + error + })?; + tracing::info!( + target: LOG_TARGET, + pubkey = %pubkey.to_hex(), + "Client transport started" + ); + + self.base + .subscribe_for_pubkey(&pubkey) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + pubkey = %pubkey.to_hex(), + "Failed to subscribe client transport for pubkey" + ); + error + })?; // Spawn event loop let client = self.base.relay_pool.client().clone(); @@ -102,6 +164,11 @@ impl NostrClientTransport { Self::event_loop(client, pending, server_pubkey, tx, encryption_mode).await; }); + tracing::info!( + target: LOG_TARGET, + relay_count = self.config.relay_urls.len(), + "Client transport event loop spawned" + ); Ok(()) } @@ -137,7 +204,17 @@ impl NostrClientTransport { tags, None, ) - .await?; + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %self.server_pubkey.to_hex(), + method = ?message.method(), + "Failed to send client message" + ); + error + })?; if matches!(message, JsonRpcMessage::Request(_)) { self.pending_requests @@ -146,6 +223,12 @@ impl NostrClientTransport { .insert(event_id.to_hex()); } + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id.to_hex(), + method = ?message.method(), + "Sent client message" + ); Ok(()) } @@ -195,8 +278,12 @@ impl NostrClientTransport { // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); continue; } }; @@ -207,14 +294,22 @@ impl NostrClientTransport { let e_tag = serializers::get_tag_value(&inner.tags, "e"); (inner.content, inner.pubkey, e_tag) } - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" + ); continue; } } } - Err(e) => { - tracing::error!("Failed to decrypt gift wrap: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt gift wrap" + ); continue; } } @@ -225,7 +320,12 @@ impl NostrClientTransport { // Verify it's from our server if actual_pubkey != server_pubkey { - tracing::debug!("Skipping event from unexpected pubkey"); + tracing::debug!( + target: LOG_TARGET, + event_pubkey = %actual_pubkey.to_hex(), + expected_pubkey = %server_pubkey.to_hex(), + "Skipping event from unexpected pubkey" + ); continue; } @@ -233,7 +333,11 @@ impl NostrClientTransport { if let Some(ref correlated_id) = e_tag { let is_pending = pending.read().await.contains(correlated_id.as_str()); if !is_pending { - tracing::warn!(e_tag = %correlated_id, "Response for unknown request"); + tracing::warn!( + target: LOG_TARGET, + correlated_event_id = %correlated_id, + "Response for unknown request" + ); continue; } } @@ -265,6 +369,7 @@ mod tests { assert_eq!(config.encryption_mode, EncryptionMode::Optional); assert!(!config.is_stateless); assert_eq!(config.timeout, Duration::from_secs(30)); + assert!(config.log_file_path.is_none()); } #[test] diff --git a/src/transport/server.rs b/src/transport/server.rs index 4ccbe15..dbf4520 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -19,6 +19,10 @@ use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; +use crate::util::tracing_setup; + +const LOG_TARGET: &str = "contextvm_sdk::transport::server"; + /// Configuration for the server transport. pub struct NostrServerTransportConfig { /// Relay URLs to connect to. @@ -37,6 +41,8 @@ pub struct NostrServerTransportConfig { pub cleanup_interval: Duration, /// Session timeout (default: 300s). pub session_timeout: Duration, + /// Optional log file path. Logs always go to stdout and are also appended here when set. + pub log_file_path: Option, } impl Default for NostrServerTransportConfig { @@ -50,6 +56,7 @@ impl Default for NostrServerTransportConfig { excluded_capabilities: Vec::new(), cleanup_interval: Duration::from_secs(60), session_timeout: Duration::from_secs(300), + log_file_path: None, } } } @@ -86,9 +93,25 @@ impl NostrServerTransport { where T: IntoNostrSigner, { - let relay_pool = Arc::new(RelayPool::new(signer).await?); + tracing_setup::init_tracer(config.log_file_path.as_deref())?; + + let relay_pool = Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for server transport" + ); + error + })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + announced = config.is_announced_server, + encryption_mode = ?config.encryption_mode, + "Created server transport" + ); Ok(Self { base: BaseTransport { relay_pool, @@ -105,12 +128,44 @@ impl NostrServerTransport { /// Start listening for incoming requests. pub async fn start(&mut self) -> Result<()> { - self.base.connect(&self.config.relay_urls).await?; - - let pubkey = self.base.get_public_key().await?; - tracing::info!(pubkey = %pubkey.to_hex(), "Server transport started"); + self.base + .connect(&self.config.relay_urls) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to connect server transport to relays" + ); + error + })?; + + let pubkey = self.base.get_public_key().await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to fetch server transport public key" + ); + error + })?; + tracing::info!( + target: LOG_TARGET, + pubkey = %pubkey.to_hex(), + "Server transport started" + ); - self.base.subscribe_for_pubkey(&pubkey).await?; + self.base + .subscribe_for_pubkey(&pubkey) + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + pubkey = %pubkey.to_hex(), + "Failed to subscribe server transport for pubkey" + ); + error + })?; // Spawn event loop let client = self.base.relay_pool.client().clone(); @@ -151,11 +206,22 @@ impl NostrServerTransport { ) .await; if cleaned > 0 { - tracing::info!(cleaned, "Cleaned up inactive sessions"); + tracing::info!( + target: LOG_TARGET, + cleaned_sessions = cleaned, + "Cleaned up inactive sessions" + ); } } }); + tracing::info!( + target: LOG_TARGET, + relay_count = self.config.relay_urls.len(), + cleanup_interval_secs = self.config.cleanup_interval.as_secs(), + session_timeout_secs = self.config.session_timeout.as_secs(), + "Server transport loops spawned" + ); Ok(()) } @@ -172,14 +238,26 @@ impl NostrServerTransport { let event_to_client = self.event_to_client.read().await; let client_pubkey_hex = event_to_client .get(event_id) - .ok_or_else(|| Error::Other(format!("No client found for event {event_id}")))? + .ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + event_id = %event_id, + "No client found for response correlation" + ); + Error::Other(format!("No client found for event {event_id}")) + })? .clone(); drop(event_to_client); let sessions = self.sessions.read().await; - let session = sessions - .get(&client_pubkey_hex) - .ok_or_else(|| Error::Other(format!("No session for client {client_pubkey_hex}")))?; + let session = sessions.get(&client_pubkey_hex).ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + client_pubkey = %client_pubkey_hex, + "No session for correlated client" + ); + Error::Other(format!("No session for client {client_pubkey_hex}")) + })?; // Restore original request ID if let Some(original_id) = session.pending_requests.get(event_id) { @@ -193,11 +271,25 @@ impl NostrServerTransport { let is_encrypted = session.is_encrypted; drop(sessions); - let client_pubkey = - PublicKey::from_hex(&client_pubkey_hex).map_err(|e| Error::Other(e.to_string()))?; - - let event_id_parsed = - EventId::from_hex(event_id).map_err(|e| Error::Other(e.to_string()))?; + let client_pubkey = PublicKey::from_hex(&client_pubkey_hex).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %client_pubkey_hex, + "Invalid client pubkey in session map" + ); + Error::Other(error.to_string()) + })?; + + let event_id_parsed = EventId::from_hex(event_id).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + event_id = %event_id, + "Invalid event id while sending response" + ); + Error::Other(error.to_string()) + })?; let tags = BaseTransport::create_response_tags(&client_pubkey, &event_id_parsed); @@ -209,7 +301,17 @@ impl NostrServerTransport { tags, Some(is_encrypted), ) - .await?; + .await + .map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %client_pubkey_hex, + event_id = %event_id, + "Failed to publish response message" + ); + error + })?; // Clean up let mut sessions = self.sessions.write().await; @@ -224,6 +326,13 @@ impl NostrServerTransport { self.event_to_client.write().await.remove(event_id); + tracing::debug!( + target: LOG_TARGET, + client_pubkey = %client_pubkey_hex, + event_id = %event_id, + encrypted = is_encrypted, + "Sent server response and cleaned correlation state" + ); Ok(()) } @@ -274,8 +383,13 @@ impl NostrServerTransport { drop(sessions); for pubkey in initialized { - if let Err(e) = self.send_notification(&pubkey, notification, None).await { - tracing::error!(client = %pubkey, "Failed to send notification: {e}"); + if let Err(error) = self.send_notification(&pubkey, notification, None).await { + tracing::error!( + target: LOG_TARGET, + error = %error, + client_pubkey = %pubkey, + "Failed to send notification" + ); } } Ok(()) @@ -489,14 +603,23 @@ impl NostrServerTransport { || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) { if encryption_mode == EncryptionMode::Disabled { - tracing::warn!("Received encrypted message but encryption is disabled"); + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + sender_pubkey = %event.pubkey.to_hex(), + "Received encrypted message but encryption is disabled" + ); continue; } // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, - Err(e) => { - tracing::error!("Failed to get signer: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); continue; } }; @@ -512,21 +635,30 @@ impl NostrServerTransport { inner.id.to_hex(), true, ), - Err(e) => { - tracing::error!("Failed to parse inner event: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" + ); continue; } } } - Err(e) => { - tracing::error!("Failed to decrypt: {e}"); + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt" + ); continue; } } } else { if encryption_mode == EncryptionMode::Required { tracing::warn!( - pubkey = %event.pubkey, + target: LOG_TARGET, + sender_pubkey = %event.pubkey.to_hex(), "Received unencrypted message but encryption is required" ); continue; @@ -543,7 +675,11 @@ impl NostrServerTransport { let mcp_msg = match serializers::nostr_event_to_mcp_message(&content) { Some(msg) => msg, None => { - tracing::warn!("Invalid MCP message from {sender_pubkey}"); + tracing::warn!( + target: LOG_TARGET, + sender_pubkey = %sender_pubkey, + "Invalid MCP message" + ); continue; } }; @@ -565,8 +701,9 @@ impl NostrServerTransport { if !allowed_pubkeys.contains(&sender_pubkey) && !is_excluded { tracing::warn!( - pubkey = %sender_pubkey, - method = %method, + target: LOG_TARGET, + sender_pubkey = %sender_pubkey, + method = method, "Unauthorized request" ); continue; @@ -647,7 +784,11 @@ impl NostrServerTransport { for event_id in session.event_to_progress_token.keys() { event_map.remove(event_id); } - tracing::debug!(client = %pubkey, "Session expired"); + tracing::debug!( + target: LOG_TARGET, + client_pubkey = %pubkey, + "Session expired" + ); cleaned += 1; false } else { @@ -882,5 +1023,6 @@ mod tests { assert_eq!(config.cleanup_interval, Duration::from_secs(60)); assert_eq!(config.session_timeout, Duration::from_secs(300)); assert!(config.server_info.is_none()); + assert!(config.log_file_path.is_none()); } } From bbe369d72ea5f4585fbd5bac5b94a95f75c01767 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 03:48:17 +0530 Subject: [PATCH 19/69] feat: added exmaples with log files --- Cargo.toml | 2 +- examples/gateway.rs | 23 ++++++++++++++++++++++- examples/proxy.rs | 35 ++++++++++++++++++++++++++++++----- src/core/constants.rs | 1 - 4 files changed, 53 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0ca4b98..e4e03ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } @@ -34,6 +35,5 @@ rmcp = ["dep:rmcp"] [dev-dependencies] tokio-test = "0.4" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } anyhow = "1" schemars = "0.8" diff --git a/examples/gateway.rs b/examples/gateway.rs index 2effaba..41543d1 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -2,6 +2,8 @@ //! //! This demonstrates how to create a ContextVM gateway that receives //! MCP requests over Nostr and responds to them. +//! +//! Usage: cargo run --example gateway -- [--log-file ] use contextvm_sdk::core::types::*; use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; @@ -10,7 +12,25 @@ use contextvm_sdk::transport::server::NostrServerTransportConfig; #[tokio::main] async fn main() -> contextvm_sdk::Result<()> { - tracing_subscriber::fmt::init(); + let args: Vec = std::env::args().skip(1).collect(); + let mut log_file_path: Option = None; + + let mut index = 0; + while index < args.len() { + match args[index].as_str() { + "--log-file" => { + index += 1; + let Some(path) = args.get(index) else { + panic!("Usage: gateway [--log-file ]"); + }; + log_file_path = Some(path.clone()); + } + other => { + panic!("Unknown argument: {other}. Usage: gateway [--log-file ]"); + } + } + index += 1; + } // Generate ephemeral keys for this session let keys = signer::generate(); @@ -26,6 +46,7 @@ async fn main() -> contextvm_sdk::Result<()> { ..Default::default() }), is_announced_server: true, + log_file_path, ..Default::default() }, }; diff --git a/examples/proxy.rs b/examples/proxy.rs index c0fcd55..da5a63e 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -1,6 +1,6 @@ //! Example: Connect to a remote MCP server via Nostr and call tools/list. //! -//! Usage: cargo run --example proxy -- +//! Usage: cargo run --example proxy -- [--log-file ] use contextvm_sdk::core::types::*; use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; @@ -8,11 +8,35 @@ use contextvm_sdk::signer; use contextvm_sdk::transport::client::NostrClientTransportConfig; #[tokio::main] async fn main() -> contextvm_sdk::Result<()> { - tracing_subscriber::fmt::init(); + let args: Vec = std::env::args().skip(1).collect(); + let mut server_pubkey_hex: Option = None; + let mut log_file_path: Option = None; - let server_pubkey_hex = std::env::args() - .nth(1) - .expect("Usage: proxy "); + let mut index = 0; + while index < args.len() { + match args[index].as_str() { + "--log-file" => { + index += 1; + let Some(path) = args.get(index) else { + panic!("Usage: proxy [--log-file ]"); + }; + log_file_path = Some(path.clone()); + } + value => { + if server_pubkey_hex.is_none() { + server_pubkey_hex = Some(value.to_string()); + } else { + panic!( + "Unknown argument: {value}. Usage: proxy [--log-file ]" + ); + } + } + } + index += 1; + } + + let server_pubkey_hex = + server_pubkey_hex.expect("Usage: proxy [--log-file ]"); let keys = signer::generate(); println!("Client pubkey: {}", keys.public_key().to_hex()); @@ -22,6 +46,7 @@ async fn main() -> contextvm_sdk::Result<()> { relay_urls: vec!["wss://relay.damus.io".to_string()], server_pubkey: server_pubkey_hex, encryption_mode: EncryptionMode::Optional, + log_file_path, ..Default::default() }, }; diff --git a/src/core/constants.rs b/src/core/constants.rs index 85cf82b..870d165 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -109,7 +109,6 @@ pub const UNENCRYPTED_KINDS: &[u16] = &[ PROMPTS_LIST_KIND, ]; - #[cfg(feature = "rmcp")] pub fn mcp_protocol_version() -> &'static str { use std::sync::OnceLock; From 7f4785755d386cca3eeda1b1df661dcb3ec0d97b Mon Sep 17 00:00:00 2001 From: Kushagra Date: Tue, 7 Apr 2026 05:21:19 +0530 Subject: [PATCH 20/69] setup basic github workflow --- .github/workflows/rust.yml | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..cc4327a --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,33 @@ +name: Rust CI + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + ci: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Check + run: cargo check --all --all-features + + - name: Test + run: cargo test --all --all-features + + - name: Build + run: cargo build --verbose \ No newline at end of file From 75bb1abd42c8da8432e4f1c30c51071e1d33d26e Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Tue, 7 Apr 2026 18:03:34 +0530 Subject: [PATCH 21/69] fix: verify inner event signatures after gift-wrap decryption --- src/encryption/mod.rs | 41 +++++++++++++++++++++++++++++++++++++++++ src/transport/client.rs | 4 ++++ src/transport/server.rs | 20 ++++++++++++++------ 3 files changed, 59 insertions(+), 6 deletions(-) diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index 0d5b369..088ca8b 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -264,4 +264,45 @@ mod tests { // (it uses an ephemeral key, like the JS SDK) assert_ne!(gift_wrap_event.pubkey, sender_keys.public_key()); } + + /// Regression: gift-wrapped inner events with a tampered pubkey must be + /// caught by `Event::verify()`. + #[tokio::test] + async fn test_forged_inner_event_detected_by_verify() { + let real_sender = Keys::generate(); + let impersonated = Keys::generate(); + let recipient = Keys::generate(); + + let mcp_content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + + // Step 1: build a legitimately signed inner event + let inner_event = EventBuilder::new(Kind::Custom(25910), mcp_content) + .tag(Tag::public_key(recipient.public_key())) + .sign_with_keys(&real_sender) + .unwrap(); + + // Step 2: tamper the pubkey (keep original, now-invalid, signature) + let mut forged_json: serde_json::Value = + serde_json::to_value(&inner_event).unwrap(); + forged_json["pubkey"] = + serde_json::Value::String(impersonated.public_key().to_hex()); + let forged_str = serde_json::to_string(&forged_json).unwrap(); + + // Step 3: gift-wrap the forged payload + let (gift_wrap, _) = + create_simple_gift_wrap(&forged_str, &recipient.public_key()).await; + + // Decrypt + parse both succeed — the forgery is syntactically valid + let decrypted = decrypt_gift_wrap_single_layer(&recipient, &gift_wrap) + .await + .unwrap(); + let parsed: Event = serde_json::from_str(&decrypted).unwrap(); + assert_eq!(parsed.pubkey, impersonated.public_key()); + + // Signature verification catches the tampered pubkey + assert!( + parsed.verify().is_err(), + "forged inner event must fail signature verification" + ); + } } diff --git a/src/transport/client.rs b/src/transport/client.rs index 1cc1a0e..cef1948 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -204,6 +204,10 @@ impl NostrClientTransport { Ok(decrypted_json) => { match serde_json::from_str::(&decrypted_json) { Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!("Inner event signature verification failed: {e}"); + continue; + } let e_tag = serializers::get_tag_value(&inner.tags, "e"); (inner.content, inner.pubkey, e_tag) } diff --git a/src/transport/server.rs b/src/transport/server.rs index 4ccbe15..82f0552 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -506,12 +506,20 @@ impl NostrServerTransport { // Use the INNER event's ID for correlation — the client // registers the inner event ID in its correlation store. match serde_json::from_str::(&decrypted_json) { - Ok(inner) => ( - inner.content, - inner.pubkey.to_hex(), - inner.id.to_hex(), - true, - ), + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!( + "Inner event signature verification failed: {e}" + ); + continue; + } + ( + inner.content, + inner.pubkey.to_hex(), + inner.id.to_hex(), + true, + ) + } Err(e) => { tracing::error!("Failed to parse inner event: {e}"); continue; From 860834070c359503afa570464ed6665765a452e5 Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 7 Apr 2026 17:55:37 +0530 Subject: [PATCH 22/69] test: add conformance tests for initialization flow wire format --- tests/conformance_wire_format.rs | 183 +++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 tests/conformance_wire_format.rs diff --git a/tests/conformance_wire_format.rs b/tests/conformance_wire_format.rs new file mode 100644 index 0000000..697376a --- /dev/null +++ b/tests/conformance_wire_format.rs @@ -0,0 +1,183 @@ +//! Conformance tests for ContextVM wire format: MCP JSON-RPC carried in Nostr kind 25910 events. +//! +//! These mirror the layering style of `src/rmcp_transport/pipeline_tests.rs`: build the JSON-RPC +//! payload, serialize through the same helpers the transport uses (`mcp_to_nostr_event`, tag +//! builders from [`BaseTransport`]), sign with nostr-sdk, then assert on kind, tags, and content. + +use contextvm_sdk::core::constants::{ + mcp_protocol_version, tags, CTXVM_MESSAGES_KIND, INITIALIZE_METHOD, + NOTIFICATIONS_INITIALIZED_METHOD, +}; +use contextvm_sdk::core::serializers; +use contextvm_sdk::core::types::{ + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, +}; +use contextvm_sdk::transport::base::BaseTransport; +use nostr_sdk::prelude::*; + +fn assert_ctxvm_message_kind(event: &Event) { + assert_eq!( + event.kind, + Kind::Custom(CTXVM_MESSAGES_KIND), + "ContextVM MCP messages must use kind {}", + CTXVM_MESSAGES_KIND + ); +} + +fn p_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::PUBKEY) +} + +fn e_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::EVENT_ID) +} + +// ── Initialize request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_request_has_kind_p_tag_and_jsonrpc_initialize() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("initialize request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign initialize request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialize request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some(INITIALIZE_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(1)); +} + +// ── Initialize response ────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_response_has_kind_e_tag_and_result_protocol_version() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + // Signed request event provides the Nostr event id referenced by e on the response. + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({})), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("request event for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign request event for correlation"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { + "name": "conformance-test-server", + "version": "0.0.0" + }, + "capabilities": {} + }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&init_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("initialize response should serialize") + .sign_with_keys(&server_keys) + .expect("sign initialize response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "initialize response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "initialize response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-1")); + assert!(v["result"]["protocolVersion"].is_string()); + assert!(v["result"]["serverInfo"]["name"].is_string()); +} + +// ── notifications/initialized ────────────────────────────────────────────── + +#[test] +fn ctxvm_notifications_initialized_has_kind_p_tag_and_method() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: NOTIFICATIONS_INITIALIZED_METHOD.to_string(), + params: None, + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let event = serializers::mcp_to_nostr_event(¬if, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("notification should serialize") + // Client sends this to the server; signer must differ from `p` so the tag is not stripped. + .sign_with_keys(&client_keys) + .expect("sign initialized notification"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialized notification must include server p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content).expect("parse content"); + assert!(msg.is_notification()); + assert_eq!(msg.method(), Some(NOTIFICATIONS_INITIALIZED_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert!( + v.get("id").map_or(true, serde_json::Value::is_null), + "JSON-RPC notifications must not include an id" + ); +} From c3a90736199681d37a0d584160d289913febc83b Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Tue, 7 Apr 2026 18:33:03 +0530 Subject: [PATCH 23/69] fix: resolve all clippy warnings, remove orphaned doc comment and duplicate constant - Remove KIND_GIFT_WRAP (duplicate of GIFT_WRAP_KIND); update usage in encryption tests - Remove orphaned doc comment that was attached to the wrong item (DEFAULT_LRU_SIZE) - Promote 6 const-value range assertions from runtime tests to a compile-time const block cargo clippy --all-targets now reports 0 warnings. All 101 tests pass. --- src/core/constants.rs | 40 +++++++++++++--------------------------- src/encryption/mod.rs | 4 ++-- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/src/core/constants.rs b/src/core/constants.rs index 85cf82b..e720fb5 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -33,7 +33,6 @@ pub const RESOURCETEMPLATES_LIST_KIND: u16 = 11319; /// Prompts list (addressable, kind 11320) pub const PROMPTS_LIST_KIND: u16 = 11320; -pub const KIND_GIFT_WRAP: u16 = 1059; /// Nostr tag constants pub mod tags { /// Public key tag @@ -70,11 +69,6 @@ pub mod tags { /// Maximum message size (1MB) pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; -/// MCP protocol version string used in initialize responses. -/// -/// Matches the `protocolVersion` field of the `InitializeResult` JSON-RPC response. -/// Keep this in sync with the MCP spec and rmcp's `ProtocolVersion::LATEST`. - /// Default LRU cache size for deduplication pub const DEFAULT_LRU_SIZE: usize = 5000; @@ -124,6 +118,19 @@ pub const fn mcp_protocol_version() -> &'static str { "2025-11-25" } +// Compile-time range checks (NIP-01 kind ranges). +// Placed at module level so violations are caught in every build, not just `cargo test`. +const _: () = { + // Ephemeral events: 20000 <= kind < 30000 + assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); + assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); + assert!(CTXVM_MESSAGES_KIND >= 20000); + assert!(CTXVM_MESSAGES_KIND < 30000); + // Replaceable events: 10000 <= kind < 20000 + assert!(RELAY_LIST_METADATA_KIND >= 10000); + assert!(RELAY_LIST_METADATA_KIND < 20000); +}; + #[cfg(test)] mod tests { use super::*; @@ -158,27 +165,6 @@ mod tests { ); } - #[test] - fn test_ephemeral_gift_wrap_in_ephemeral_range() { - // NIP-01: ephemeral events are 20000 <= kind < 30000 - assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); - assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); - } - - #[test] - fn test_ctxvm_messages_in_ephemeral_range() { - // NIP-01: ephemeral events are 20000 <= kind < 30000 - assert!(CTXVM_MESSAGES_KIND >= 20000); - assert!(CTXVM_MESSAGES_KIND < 30000); - } - - #[test] - fn test_relay_list_metadata_in_replaceable_range() { - // NIP-01: replaceable events are 10000 <= kind < 20000 - assert!(RELAY_LIST_METADATA_KIND >= 10000); - assert!(RELAY_LIST_METADATA_KIND < 20000); - } - #[test] fn test_announcement_kinds_in_addressable_range() { // NIP-01: addressable events are 30000 <= kind < 40000 diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index 0d5b369..2798a57 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -111,7 +111,7 @@ pub async fn gift_wrap( #[cfg(test)] mod tests { - use crate::core::constants::KIND_GIFT_WRAP; + use crate::core::constants::GIFT_WRAP_KIND; use super::*; @@ -153,7 +153,7 @@ mod tests { .unwrap(); // Build kind 1059 event - let builder = EventBuilder::new(Kind::from(KIND_GIFT_WRAP), encrypted) + let builder = EventBuilder::new(Kind::from(GIFT_WRAP_KIND), encrypted) .tag(Tag::public_key(*recipient)); let event = builder.sign_with_keys(&ephemeral).unwrap(); From 3f00ece9814afb3b4066ef984c96b94d33213d7f Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Wed, 8 Apr 2026 03:18:11 +0530 Subject: [PATCH 24/69] fix: route event loops through validated message parsing Both event loops called serializers::nostr_event_to_mcp_message() directly, bypassing size validation and JSON-RPC 2.0 structure checks. Replace with validation::validate_and_parse() which enforces both. --- src/core/validation.rs | 37 +++++++++++++++++++++++++++++++++++++ src/transport/base.rs | 8 +------- src/transport/client.rs | 3 ++- src/transport/server.rs | 4 ++-- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/core/validation.rs b/src/core/validation.rs index d409b7b..9e6dc5d 100644 --- a/src/core/validation.rs +++ b/src/core/validation.rs @@ -8,6 +8,17 @@ pub fn validate_message_size(content: &str) -> bool { content.len() <= MAX_MESSAGE_SIZE } +/// Validate size and structure, then parse into a [`JsonRpcMessage`]. +pub fn validate_and_parse(content: &str) -> Option { + if !validate_message_size(content) { + tracing::warn!("Message size validation failed: {} bytes", content.len()); + return None; + } + + let value: serde_json::Value = serde_json::from_str(content).ok()?; + validate_message(&value) +} + /// Validate that a JSON value is a well-formed JSON-RPC 2.0 message. /// /// Checks: @@ -61,4 +72,30 @@ mod tests { let big = "x".repeat(MAX_MESSAGE_SIZE + 1); assert!(!validate_message_size(&big)); } + + #[test] + fn test_validate_and_parse_valid_request() { + let content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let msg = validate_and_parse(content).unwrap(); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/list")); + } + + #[test] + fn test_validate_and_parse_rejects_oversized() { + let padding = "x".repeat(MAX_MESSAGE_SIZE); + let content = format!(r#"{{"jsonrpc":"2.0","id":1,"method":"{}"}}"#, padding); + assert!(validate_and_parse(&content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_version() { + let content = r#"{"jsonrpc":"1.0","id":1,"method":"test"}"#; + assert!(validate_and_parse(content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_json() { + assert!(validate_and_parse("not json").is_none()); + } } diff --git a/src/transport/base.rs b/src/transport/base.rs index 419a4b6..58f762b 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -87,13 +87,7 @@ impl BaseTransport { /// Convert a Nostr event to an MCP message with validation. pub fn convert_event_to_mcp(&self, content: &str) -> Option { - if !validation::validate_message_size(content) { - tracing::warn!("Message size validation failed: {} bytes", content.len()); - return None; - } - - let value: serde_json::Value = serde_json::from_str(content).ok()?; - validation::validate_message(&value) + validation::validate_and_parse(content) } /// Create a signed Nostr event for an MCP message. diff --git a/src/transport/client.rs b/src/transport/client.rs index cef1948..8ef111e 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -14,6 +14,7 @@ use crate::core::constants::*; use crate::core::error::{Error, Result}; use crate::core::serializers; use crate::core::types::*; +use crate::core::validation; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; @@ -244,7 +245,7 @@ impl NostrClientTransport { // Parse MCP message if let Some(mcp_msg) = - serializers::nostr_event_to_mcp_message(&actual_event_content) + validation::validate_and_parse(&actual_event_content) { // Clean up pending request if let Some(ref correlated_id) = e_tag { diff --git a/src/transport/server.rs b/src/transport/server.rs index 82f0552..a17ed51 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -13,8 +13,8 @@ use tokio::sync::RwLock; use crate::core::constants::*; use crate::core::error::{Error, Result}; -use crate::core::serializers; use crate::core::types::*; +use crate::core::validation; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; @@ -548,7 +548,7 @@ impl NostrServerTransport { }; // Parse MCP message - let mcp_msg = match serializers::nostr_event_to_mcp_message(&content) { + let mcp_msg = match validation::validate_and_parse(&content) { Some(msg) => msg, None => { tracing::warn!("Invalid MCP message from {sender_pubkey}"); From 68b47118dcad448235654299a5c939021965d3bd Mon Sep 17 00:00:00 2001 From: Harsh Date: Wed, 8 Apr 2026 04:00:31 +0530 Subject: [PATCH 25/69] test: add wire format conformance tests for tools/list, tools/call, and server announcement --- tests/conformance_wire_format.rs | 284 ++++++++++++++++++++++++++++++- 1 file changed, 282 insertions(+), 2 deletions(-) diff --git a/tests/conformance_wire_format.rs b/tests/conformance_wire_format.rs index 697376a..ee039fc 100644 --- a/tests/conformance_wire_format.rs +++ b/tests/conformance_wire_format.rs @@ -6,7 +6,7 @@ use contextvm_sdk::core::constants::{ mcp_protocol_version, tags, CTXVM_MESSAGES_KIND, INITIALIZE_METHOD, - NOTIFICATIONS_INITIALIZED_METHOD, + NOTIFICATIONS_INITIALIZED_METHOD, SERVER_ANNOUNCEMENT_KIND, }; use contextvm_sdk::core::serializers; use contextvm_sdk::core::types::{ @@ -177,7 +177,287 @@ fn ctxvm_notifications_initialized_has_kind_p_tag_and_method() { serde_json::from_str(&event.content).expect("content must be JSON object"); assert_eq!(v["jsonrpc"], "2.0"); assert!( - v.get("id").map_or(true, serde_json::Value::is_null), + v.get("id").is_none_or(serde_json::Value::is_null), "JSON-RPC notifications must not include an id" ); } + +// ── tools/list request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_list_request_has_kind_p_tag_and_method() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let list_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&list_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/list request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign tools/list request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "tools/list request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/list")); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(2)); +} + +// ── tools/call request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_call_request_has_kind_p_tag_method_and_params() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let call_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(3), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "add", + "arguments": { "a": 5, "b": 3 } + })), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&call_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/call request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign tools/call request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "tools/call request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/call")); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(3)); + assert_eq!(v["params"]["name"], "add"); + assert!( + v["params"]["arguments"].is_object(), + "tools/call params.arguments must be an object on the wire" + ); +} + +// ── tools/list response ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_list_response_has_kind_e_tag_and_result() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + let list_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-list"), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = + serializers::mcp_to_nostr_event(&list_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/list request for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign tools/list request event for correlation"); + + let list_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-list"), + result: serde_json::json!({ "tools": [] }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&list_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("tools/list response should serialize") + .sign_with_keys(&server_keys) + .expect("sign tools/list response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "tools/list response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "tools/list response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-tools-list")); + assert!(v["result"]["tools"].is_array()); +} + +// ── tools/call response ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_tools_call_response_has_kind_e_tag_and_result() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + let call_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-call"), + method: "tools/call".to_string(), + params: Some(serde_json::json!({ + "name": "add", + "arguments": { "a": 5, "b": 3 } + })), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = + serializers::mcp_to_nostr_event(&call_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("tools/call request for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign tools/call request event for correlation"); + + let call_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-tools-call"), + result: serde_json::json!({ + "content": [{ "type": "text", "text": "8" }], + "isError": false + }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&call_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("tools/call response should serialize") + .sign_with_keys(&server_keys) + .expect("sign tools/call response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "tools/call response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "tools/call response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-tools-call")); + assert!(v["result"]["content"].is_array()); + assert_eq!(v["result"]["isError"], serde_json::json!(false)); +} + +// ── Server announcement (kind 11316) ────────────────────────────────────────── + +#[test] +fn ctxvm_server_announcement_has_kind_and_required_tags() { + let server_keys = Keys::generate(); + + // MCP-flavoured JSON for wire conformance; not the same content shape as `NostrServerTransport::announce` (flat `ServerInfo` only). + let content = serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "Test Server" }, + "capabilities": {}, + }); + let content_str = serde_json::to_string(&content).expect("announcement content must serialize"); + + let announcement_tags = vec![ + Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec!["Test Server".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::ABOUT.into()), + vec!["A test server".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::WEBSITE.into()), + vec!["http://localhost".to_string()], + ), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + ), + ]; + + let event = EventBuilder::new(Kind::Custom(SERVER_ANNOUNCEMENT_KIND), content_str) + .tags(announcement_tags) + .sign_with_keys(&server_keys) + .expect("sign server announcement event"); + + assert_eq!( + event.kind, + Kind::Custom(SERVER_ANNOUNCEMENT_KIND), + "server announcement must use kind {}", + SERVER_ANNOUNCEMENT_KIND + ); + assert_eq!(event.pubkey, server_keys.public_key()); + + assert_eq!( + serializers::get_tag_value(&event.tags, tags::NAME).as_deref(), + Some("Test Server") + ); + assert_eq!( + serializers::get_tag_value(&event.tags, tags::ABOUT).as_deref(), + Some("A test server") + ); + assert_eq!( + serializers::get_tag_value(&event.tags, tags::WEBSITE).as_deref(), + Some("http://localhost") + ); + + assert!( + event.tags.iter().any(|t| { + let parts = t.clone().to_vec(); + parts.len() == 1 + && parts.first().map(|s| s.as_str()) == Some(tags::SUPPORT_ENCRYPTION) + }), + "support_encryption must be present as a single-element tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("announcement content must be JSON"); + assert_eq!(v["protocolVersion"], mcp_protocol_version()); + assert_eq!(v["serverInfo"]["name"], "Test Server"); + assert!(v["capabilities"].is_object()); +} From dbd4484a40600537000a9be5df57620c9df168ad Mon Sep 17 00:00:00 2001 From: Kushagra Date: Wed, 8 Apr 2026 04:06:39 +0530 Subject: [PATCH 26/69] fix: made the log format across all the files consistent --- src/rmcp_transport/worker.rs | 37 ++++++++++++++++++++++++++++++------ src/transport/base.rs | 15 +++++++++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 0d52db0..79bb3dd 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -16,6 +16,8 @@ use super::convert::{ rmcp_server_tx_to_internal, }; +const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::worker"; + /// rmcp server worker wrapper for ContextVM Nostr server transport. pub struct NostrServerWorker { transport: NostrServerTransport, @@ -95,6 +97,7 @@ impl Worker for NostrServerWorker { match &self.active_client_pubkey { Some(active) if active != &client_pubkey => { tracing::warn!( + target: LOG_TARGET, active_client = %active, ignored_client = %client_pubkey, "Ignoring message from second client: rmcp server worker currently supports one active client per worker" @@ -102,7 +105,11 @@ impl Worker for NostrServerWorker { continue; } None => { - tracing::info!(client_pubkey = %client_pubkey, "Binding rmcp server worker to first client session"); + tracing::info!( + target: LOG_TARGET, + client_pubkey = %client_pubkey, + "Binding rmcp server worker to first client session" + ); self.active_client_pubkey = Some(client_pubkey.clone()); } _ => {} @@ -114,7 +121,11 @@ impl Worker for NostrServerWorker { self.request_id_to_event_id.insert(request_key, event_id); } Err(e) => { - tracing::warn!("Failed to serialize request id for correlation map: {e}"); + tracing::warn!( + target: LOG_TARGET, + error = %e, + "Failed to serialize request id for correlation map" + ); } } } @@ -124,7 +135,10 @@ impl Worker for NostrServerWorker { break reason; } } else { - tracing::warn!("Failed to convert incoming server-side message to rmcp format"); + tracing::warn!( + target: LOG_TARGET, + "Failed to convert incoming server-side message to rmcp format" + ); } } outbound = context.recv_from_handler() => { @@ -147,7 +161,11 @@ impl Worker for NostrServerWorker { }; if let Err(e) = self.transport.close().await { - tracing::warn!("Failed to close server transport cleanly: {e}"); + tracing::warn!( + target: LOG_TARGET, + error = %e, + "Failed to close server transport cleanly" + ); } Err(quit_reason) @@ -220,7 +238,10 @@ impl Worker for NostrClientWorker { break reason; } } else { - tracing::warn!("Failed to convert incoming client-side message to rmcp format"); + tracing::warn!( + target: LOG_TARGET, + "Failed to convert incoming client-side message to rmcp format" + ); } } outbound = context.recv_from_handler() => { @@ -243,7 +264,11 @@ impl Worker for NostrClientWorker { }; if let Err(e) = self.transport.close().await { - tracing::warn!("Failed to close client transport cleanly: {e}"); + tracing::warn!( + target: LOG_TARGET, + error = %e, + "Failed to close client transport cleanly" + ); } Err(quit_reason) diff --git a/src/transport/base.rs b/src/transport/base.rs index 419a4b6..e0bd024 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -11,6 +11,8 @@ use crate::core::validation; use crate::encryption; use crate::relay::RelayPool; +const LOG_TARGET: &str = "contextvm_sdk::transport::base"; + /// Shared transport logic for both client and server. /// /// Handles relay connectivity, event signing/publishing, encryption decisions, @@ -88,7 +90,11 @@ impl BaseTransport { /// Convert a Nostr event to an MCP message with validation. pub fn convert_event_to_mcp(&self, content: &str) -> Option { if !validation::validate_message_size(content) { - tracing::warn!("Message size validation failed: {} bytes", content.len()); + tracing::warn!( + target: LOG_TARGET, + content_size_bytes = content.len(), + "Message size validation failed" + ); return None; } @@ -139,13 +145,18 @@ impl BaseTransport { encryption::gift_wrap_single_layer(&signer, recipient, &event_json).await?; self.relay_pool.publish_event(&gift_wrap_event).await?; tracing::debug!( + target: LOG_TARGET, signed_event_id = %signed_event_id, envelope_id = %gift_wrap_event.id, "Sent encrypted MCP message" ); } else { self.relay_pool.publish_event(&event).await?; - tracing::debug!(signed_event_id = %signed_event_id, "Sent unencrypted MCP message"); + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + "Sent unencrypted MCP message" + ); } Ok(signed_event_id) From 4f5f91949e284f52871f3eed7e1f1cf4b64fc6c9 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Wed, 8 Apr 2026 04:23:35 +0530 Subject: [PATCH 27/69] fix: fixed validation logic and improved consistency --- .github/workflows/rust.yml | 33 ++++++ src/core/constants.rs | 40 +++---- src/core/validation.rs | 37 +++++++ src/encryption/mod.rs | 45 +++++++- src/transport/base.rs | 12 +- src/transport/client.rs | 7 +- src/transport/server.rs | 24 ++-- tests/conformance_wire_format.rs | 183 +++++++++++++++++++++++++++++++ 8 files changed, 332 insertions(+), 49 deletions(-) create mode 100644 .github/workflows/rust.yml create mode 100644 tests/conformance_wire_format.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..cc4327a --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,33 @@ +name: Rust CI + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + ci: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Check + run: cargo check --all --all-features + + - name: Test + run: cargo test --all --all-features + + - name: Build + run: cargo build --verbose \ No newline at end of file diff --git a/src/core/constants.rs b/src/core/constants.rs index 870d165..f610637 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -33,7 +33,6 @@ pub const RESOURCETEMPLATES_LIST_KIND: u16 = 11319; /// Prompts list (addressable, kind 11320) pub const PROMPTS_LIST_KIND: u16 = 11320; -pub const KIND_GIFT_WRAP: u16 = 1059; /// Nostr tag constants pub mod tags { /// Public key tag @@ -70,11 +69,6 @@ pub mod tags { /// Maximum message size (1MB) pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; -/// MCP protocol version string used in initialize responses. -/// -/// Matches the `protocolVersion` field of the `InitializeResult` JSON-RPC response. -/// Keep this in sync with the MCP spec and rmcp's `ProtocolVersion::LATEST`. - /// Default LRU cache size for deduplication pub const DEFAULT_LRU_SIZE: usize = 5000; @@ -123,6 +117,19 @@ pub const fn mcp_protocol_version() -> &'static str { "2025-11-25" } +// Compile-time range checks (NIP-01 kind ranges). +// Placed at module level so violations are caught in every build, not just `cargo test`. +const _: () = { + // Ephemeral events: 20000 <= kind < 30000 + assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); + assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); + assert!(CTXVM_MESSAGES_KIND >= 20000); + assert!(CTXVM_MESSAGES_KIND < 30000); + // Replaceable events: 10000 <= kind < 20000 + assert!(RELAY_LIST_METADATA_KIND >= 10000); + assert!(RELAY_LIST_METADATA_KIND < 20000); +}; + #[cfg(test)] mod tests { use super::*; @@ -157,27 +164,6 @@ mod tests { ); } - #[test] - fn test_ephemeral_gift_wrap_in_ephemeral_range() { - // NIP-01: ephemeral events are 20000 <= kind < 30000 - assert!(EPHEMERAL_GIFT_WRAP_KIND >= 20000); - assert!(EPHEMERAL_GIFT_WRAP_KIND < 30000); - } - - #[test] - fn test_ctxvm_messages_in_ephemeral_range() { - // NIP-01: ephemeral events are 20000 <= kind < 30000 - assert!(CTXVM_MESSAGES_KIND >= 20000); - assert!(CTXVM_MESSAGES_KIND < 30000); - } - - #[test] - fn test_relay_list_metadata_in_replaceable_range() { - // NIP-01: replaceable events are 10000 <= kind < 20000 - assert!(RELAY_LIST_METADATA_KIND >= 10000); - assert!(RELAY_LIST_METADATA_KIND < 20000); - } - #[test] fn test_announcement_kinds_in_addressable_range() { // NIP-01: addressable events are 30000 <= kind < 40000 diff --git a/src/core/validation.rs b/src/core/validation.rs index d409b7b..9e6dc5d 100644 --- a/src/core/validation.rs +++ b/src/core/validation.rs @@ -8,6 +8,17 @@ pub fn validate_message_size(content: &str) -> bool { content.len() <= MAX_MESSAGE_SIZE } +/// Validate size and structure, then parse into a [`JsonRpcMessage`]. +pub fn validate_and_parse(content: &str) -> Option { + if !validate_message_size(content) { + tracing::warn!("Message size validation failed: {} bytes", content.len()); + return None; + } + + let value: serde_json::Value = serde_json::from_str(content).ok()?; + validate_message(&value) +} + /// Validate that a JSON value is a well-formed JSON-RPC 2.0 message. /// /// Checks: @@ -61,4 +72,30 @@ mod tests { let big = "x".repeat(MAX_MESSAGE_SIZE + 1); assert!(!validate_message_size(&big)); } + + #[test] + fn test_validate_and_parse_valid_request() { + let content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let msg = validate_and_parse(content).unwrap(); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some("tools/list")); + } + + #[test] + fn test_validate_and_parse_rejects_oversized() { + let padding = "x".repeat(MAX_MESSAGE_SIZE); + let content = format!(r#"{{"jsonrpc":"2.0","id":1,"method":"{}"}}"#, padding); + assert!(validate_and_parse(&content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_version() { + let content = r#"{"jsonrpc":"1.0","id":1,"method":"test"}"#; + assert!(validate_and_parse(content).is_none()); + } + + #[test] + fn test_validate_and_parse_rejects_invalid_json() { + assert!(validate_and_parse("not json").is_none()); + } } diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index 1e32426..a7a6d39 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -111,7 +111,7 @@ pub async fn gift_wrap( #[cfg(test)] mod tests { - use crate::core::constants::KIND_GIFT_WRAP; + use crate::core::constants::GIFT_WRAP_KIND; use super::*; @@ -150,7 +150,7 @@ mod tests { .unwrap(); // Build kind 1059 event - let builder = EventBuilder::new(Kind::from(KIND_GIFT_WRAP), encrypted) + let builder = EventBuilder::new(Kind::from(GIFT_WRAP_KIND), encrypted) .tag(Tag::public_key(*recipient)); let event = builder.sign_with_keys(&ephemeral).unwrap(); @@ -261,4 +261,45 @@ mod tests { // (it uses an ephemeral key, like the JS SDK) assert_ne!(gift_wrap_event.pubkey, sender_keys.public_key()); } + + /// Regression: gift-wrapped inner events with a tampered pubkey must be + /// caught by `Event::verify()`. + #[tokio::test] + async fn test_forged_inner_event_detected_by_verify() { + let real_sender = Keys::generate(); + let impersonated = Keys::generate(); + let recipient = Keys::generate(); + + let mcp_content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + + // Step 1: build a legitimately signed inner event + let inner_event = EventBuilder::new(Kind::Custom(25910), mcp_content) + .tag(Tag::public_key(recipient.public_key())) + .sign_with_keys(&real_sender) + .unwrap(); + + // Step 2: tamper the pubkey (keep original, now-invalid, signature) + let mut forged_json: serde_json::Value = + serde_json::to_value(&inner_event).unwrap(); + forged_json["pubkey"] = + serde_json::Value::String(impersonated.public_key().to_hex()); + let forged_str = serde_json::to_string(&forged_json).unwrap(); + + // Step 3: gift-wrap the forged payload + let (gift_wrap, _) = + create_simple_gift_wrap(&forged_str, &recipient.public_key()).await; + + // Decrypt + parse both succeed — the forgery is syntactically valid + let decrypted = decrypt_gift_wrap_single_layer(&recipient, &gift_wrap) + .await + .unwrap(); + let parsed: Event = serde_json::from_str(&decrypted).unwrap(); + assert_eq!(parsed.pubkey, impersonated.public_key()); + + // Signature verification catches the tampered pubkey + assert!( + parsed.verify().is_err(), + "forged inner event must fail signature verification" + ); + } } diff --git a/src/transport/base.rs b/src/transport/base.rs index e0bd024..2b34544 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -89,17 +89,7 @@ impl BaseTransport { /// Convert a Nostr event to an MCP message with validation. pub fn convert_event_to_mcp(&self, content: &str) -> Option { - if !validation::validate_message_size(content) { - tracing::warn!( - target: LOG_TARGET, - content_size_bytes = content.len(), - "Message size validation failed" - ); - return None; - } - - let value: serde_json::Value = serde_json::from_str(content).ok()?; - validation::validate_message(&value) + validation::validate_and_parse(content) } /// Create a signed Nostr event for an MCP message. diff --git a/src/transport/client.rs b/src/transport/client.rs index e41e518..db30252 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -14,6 +14,7 @@ use crate::core::constants::*; use crate::core::error::{Error, Result}; use crate::core::serializers; use crate::core::types::*; +use crate::core::validation; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; @@ -291,6 +292,10 @@ impl NostrClientTransport { Ok(decrypted_json) => { match serde_json::from_str::(&decrypted_json) { Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!("Inner event signature verification failed: {e}"); + continue; + } let e_tag = serializers::get_tag_value(&inner.tags, "e"); (inner.content, inner.pubkey, e_tag) } @@ -344,7 +349,7 @@ impl NostrClientTransport { // Parse MCP message if let Some(mcp_msg) = - serializers::nostr_event_to_mcp_message(&actual_event_content) + validation::validate_and_parse(&actual_event_content) { // Clean up pending request if let Some(ref correlated_id) = e_tag { diff --git a/src/transport/server.rs b/src/transport/server.rs index dbf4520..69f9d3a 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -13,8 +13,8 @@ use tokio::sync::RwLock; use crate::core::constants::*; use crate::core::error::{Error, Result}; -use crate::core::serializers; use crate::core::types::*; +use crate::core::validation; use crate::encryption; use crate::relay::RelayPool; use crate::transport::base::BaseTransport; @@ -629,12 +629,20 @@ impl NostrServerTransport { // Use the INNER event's ID for correlation — the client // registers the inner event ID in its correlation store. match serde_json::from_str::(&decrypted_json) { - Ok(inner) => ( - inner.content, - inner.pubkey.to_hex(), - inner.id.to_hex(), - true, - ), + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!( + "Inner event signature verification failed: {e}" + ); + continue; + } + ( + inner.content, + inner.pubkey.to_hex(), + inner.id.to_hex(), + true, + ) + } Err(error) => { tracing::error!( target: LOG_TARGET, @@ -672,7 +680,7 @@ impl NostrServerTransport { }; // Parse MCP message - let mcp_msg = match serializers::nostr_event_to_mcp_message(&content) { + let mcp_msg = match validation::validate_and_parse(&content) { Some(msg) => msg, None => { tracing::warn!( diff --git a/tests/conformance_wire_format.rs b/tests/conformance_wire_format.rs new file mode 100644 index 0000000..697376a --- /dev/null +++ b/tests/conformance_wire_format.rs @@ -0,0 +1,183 @@ +//! Conformance tests for ContextVM wire format: MCP JSON-RPC carried in Nostr kind 25910 events. +//! +//! These mirror the layering style of `src/rmcp_transport/pipeline_tests.rs`: build the JSON-RPC +//! payload, serialize through the same helpers the transport uses (`mcp_to_nostr_event`, tag +//! builders from [`BaseTransport`]), sign with nostr-sdk, then assert on kind, tags, and content. + +use contextvm_sdk::core::constants::{ + mcp_protocol_version, tags, CTXVM_MESSAGES_KIND, INITIALIZE_METHOD, + NOTIFICATIONS_INITIALIZED_METHOD, +}; +use contextvm_sdk::core::serializers; +use contextvm_sdk::core::types::{ + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, +}; +use contextvm_sdk::transport::base::BaseTransport; +use nostr_sdk::prelude::*; + +fn assert_ctxvm_message_kind(event: &Event) { + assert_eq!( + event.kind, + Kind::Custom(CTXVM_MESSAGES_KIND), + "ContextVM MCP messages must use kind {}", + CTXVM_MESSAGES_KIND + ); +} + +fn p_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::PUBKEY) +} + +fn e_tag_hex(event: &Event) -> Option { + serializers::get_tag_value(&event.tags, tags::EVENT_ID) +} + +// ── Initialize request ─────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_request_has_kind_p_tag_and_jsonrpc_initialize() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let builder = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("initialize request should serialize to event content"); + + let client_keys = Keys::generate(); + let event = builder + .sign_with_keys(&client_keys) + .expect("sign initialize request event"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialize request must target the server via p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content) + .expect("event content should be valid JSON-RPC"); + assert!(msg.is_request()); + assert_eq!(msg.method(), Some(INITIALIZE_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!(1)); +} + +// ── Initialize response ────────────────────────────────────────────────────── + +#[test] +fn ctxvm_initialize_response_has_kind_e_tag_and_result_protocol_version() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + let client_pk = client_keys.public_key(); + + // Signed request event provides the Nostr event id referenced by e on the response. + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({})), + }); + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let request_event = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("request event for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign request event for correlation"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-1"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { + "name": "conformance-test-server", + "version": "0.0.0" + }, + "capabilities": {} + }), + }); + + let response_tags = BaseTransport::create_response_tags(&client_pk, &request_event.id); + let response_event = + serializers::mcp_to_nostr_event(&init_resp, CTXVM_MESSAGES_KIND, response_tags) + .expect("initialize response should serialize") + .sign_with_keys(&server_keys) + .expect("sign initialize response event"); + + assert_ctxvm_message_kind(&response_event); + assert_eq!( + p_tag_hex(&response_event), + Some(client_pk.to_hex()), + "initialize response must route back to the client via p tag" + ); + assert_eq!( + e_tag_hex(&response_event), + Some(request_event.id.to_hex()), + "initialize response must correlate to the request Nostr event via e tag" + ); + + let v: serde_json::Value = + serde_json::from_str(&response_event.content).expect("content must be JSON"); + assert_eq!(v["jsonrpc"], "2.0"); + assert_eq!(v["id"], serde_json::json!("corr-1")); + assert!(v["result"]["protocolVersion"].is_string()); + assert!(v["result"]["serverInfo"]["name"].is_string()); +} + +// ── notifications/initialized ────────────────────────────────────────────── + +#[test] +fn ctxvm_notifications_initialized_has_kind_p_tag_and_method() { + let server_keys = Keys::generate(); + let server_pk = server_keys.public_key(); + let client_keys = Keys::generate(); + + let notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: NOTIFICATIONS_INITIALIZED_METHOD.to_string(), + params: None, + }); + + let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); + let event = serializers::mcp_to_nostr_event(¬if, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("notification should serialize") + // Client sends this to the server; signer must differ from `p` so the tag is not stripped. + .sign_with_keys(&client_keys) + .expect("sign initialized notification"); + + assert_ctxvm_message_kind(&event); + assert_eq!( + p_tag_hex(&event), + Some(server_pk.to_hex()), + "initialized notification must include server p tag" + ); + + let msg = serializers::nostr_event_to_mcp_message(&event.content).expect("parse content"); + assert!(msg.is_notification()); + assert_eq!(msg.method(), Some(NOTIFICATIONS_INITIALIZED_METHOD)); + + // Parse at the raw JSON level to verify wire format independently of the typed deserializer. + let v: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON object"); + assert_eq!(v["jsonrpc"], "2.0"); + assert!( + v.get("id").map_or(true, serde_json::Value::is_null), + "JSON-RPC notifications must not include an id" + ); +} From 23eaa5889d7aa49caa2cca4500b7d35c967c225e Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Wed, 8 Apr 2026 04:48:09 +0530 Subject: [PATCH 28/69] Add integration test CI harness with local nostr relay --- .github/workflows/integration.yml | 44 +++++++ tests/integration.rs | 195 ++++++++++++++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 .github/workflows/integration.yml create mode 100644 tests/integration.rs diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml new file mode 100644 index 0000000..c7e1fc7 --- /dev/null +++ b/.github/workflows/integration.yml @@ -0,0 +1,44 @@ +name: Integration Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + integration: + name: Integration tests (local relay) + runs-on: ubuntu-latest + + services: + nostr-relay: + image: scsibug/nostr-rs-relay:0.8.9 + ports: + - 8080:8080 + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: Swatinem/rust-cache@v2 + + - name: Wait for relay + run: | + for i in $(seq 1 30); do + nc -z localhost 8080 2>/dev/null && exit 0 + sleep 1 + done + echo "Relay not ready after 30s" && exit 1 + + - name: Run integration example (all scenarios) + env: + CTXVM_RELAY_URL: ws://localhost:8080 + run: cargo run --example rmcp_integration_test --features rmcp -- all diff --git a/tests/integration.rs b/tests/integration.rs new file mode 100644 index 0000000..7507057 --- /dev/null +++ b/tests/integration.rs @@ -0,0 +1,195 @@ +//! Local RMCP integration test (in-process duplex I/O, no relay required). +//! Relay-dependent scenarios live in `examples/rmcp_integration_test.rs` +//! and run via the `integration.yml` workflow against a local relay container. + +#![cfg(feature = "rmcp")] + +use rmcp::{ + handler::server::router::tool::ToolRouter, handler::server::wrapper::Parameters, model::*, + schemars, tool, tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, + ServiceExt, service::RequestContext, +}; +use std::sync::Arc; +use tokio::sync::Mutex; + +// Minimal fixture: same tools as examples/rmcp_integration_test.rs + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct AddParams { + a: i64, + b: i64, +} + +#[derive(Clone)] +struct DemoServer { + echo_count: Arc>, + tool_router: ToolRouter, +} + +impl DemoServer { + fn new() -> Self { + Self { + echo_count: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl DemoServer { + #[tool(description = "Echo a message back")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + let mut n = self.echo_count.lock().await; + *n += 1; + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo #{n}: {message}" + ))])) + } + + #[tool(description = "Add two integers")] + fn add( + &self, + Parameters(AddParams { a, b }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "{a} + {b} = {}", + a + b + ))])) + } + + #[tool(description = "Return total echo calls")] + async fn get_echo_count(&self) -> Result { + let n = self.echo_count.lock().await; + Ok(CallToolResult::success(vec![Content::text(format!( + "Total echo calls: {n}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for DemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + server_info: Implementation { + name: "integration-test".to_string(), + title: None, + version: "0.1.0".to_string(), + description: None, + icons: None, + website_url: None, + }, + instructions: None, + } + } + + async fn list_resources( + &self, + _req: Option, + _ctx: RequestContext, + ) -> Result { + Ok(ListResourcesResult { + resources: vec![ + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation(), + ], + next_cursor: None, + meta: None, + }) + } + + async fn read_resource( + &self, + req: ReadResourceRequestParams, + _ctx: RequestContext, + ) -> Result { + match req.uri.as_str() { + "demo://readme" => Ok(ReadResourceResult { + contents: vec![ResourceContents::text("Demo content.", req.uri)], + }), + other => Err(ErrorData::resource_not_found( + "not_found", + Some(serde_json::json!({ "uri": other })), + )), + } + } +} + +#[derive(Clone, Default)] +struct DemoClient; +impl ClientHandler for DemoClient {} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|c| match &c.raw { + RawContent::Text(t) => Some(t.text.clone()), + _ => None, + }) + .unwrap_or_default() +} + +// ── Test ───────────────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_local_rmcp() { + let (server_io, client_io) = tokio::io::duplex(65536); + + let server_handle = tokio::spawn(async move { + DemoServer::new() + .serve(server_io) + .await + .expect("serve") + .waiting() + .await + .expect("server error"); + }); + + let client = DemoClient.serve(client_io).await.expect("client init"); + + let tools = client.list_all_tools().await.expect("list tools"); + assert_eq!(tools.len(), 3); + + let add = client + .call_tool(CallToolRequestParams { + name: "add".into(), + arguments: serde_json::from_value(serde_json::json!({ "a": 7, "b": 5 })).ok(), + meta: None, + task: None, + }) + .await + .expect("call add"); + assert!(first_text(&add).contains("12")); + + let resources = client.list_all_resources().await.expect("list resources"); + assert_eq!(resources.len(), 1); + + match client + .call_tool(CallToolRequestParams { + name: "no_such_tool".into(), + arguments: None, + meta: None, + task: None, + }) + .await + { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => panic!("expected unknown tool to fail"), + } + + client.cancel().await.expect("cancel"); + server_handle.abort(); +} From 3cb8f50963d408f5d6ffb53e5984f139d9f793ee Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Wed, 8 Apr 2026 05:13:59 +0530 Subject: [PATCH 29/69] Fix CI: drop local relay, run integration example in local mode only --- .github/workflows/integration.yml | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index c7e1fc7..b55e77c 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -11,15 +11,9 @@ env: jobs: integration: - name: Integration tests (local relay) + name: Integration tests runs-on: ubuntu-latest - services: - nostr-relay: - image: scsibug/nostr-rs-relay:0.8.9 - ports: - - 8080:8080 - steps: - name: Checkout uses: actions/checkout@v4 @@ -30,15 +24,8 @@ jobs: - name: Cache dependencies uses: Swatinem/rust-cache@v2 - - name: Wait for relay - run: | - for i in $(seq 1 30); do - nc -z localhost 8080 2>/dev/null && exit 0 - sleep 1 - done - echo "Relay not ready after 30s" && exit 1 - - - name: Run integration example (all scenarios) - env: - CTXVM_RELAY_URL: ws://localhost:8080 - run: cargo run --example rmcp_integration_test --features rmcp -- all + - name: Run unit and local integration tests + run: cargo test --all-features + + - name: Run integration example (local RMCP) + run: cargo run --example rmcp_integration_test --features rmcp -- local From 02c0bdf71a4e029c04c8bd2f6f466fef94063dd8 Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Wed, 8 Apr 2026 10:47:37 +0530 Subject: [PATCH 30/69] Add stateless conformance tests --- tests/conformance_stateless_mode.rs | 179 ++++++++++++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 tests/conformance_stateless_mode.rs diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs new file mode 100644 index 0000000..bbddde5 --- /dev/null +++ b/tests/conformance_stateless_mode.rs @@ -0,0 +1,179 @@ +//! Conformance tests for stateless mode behavior in client transport. +//! +//! These mirror the TypeScript SDK StatelessModeHandler expectations at the +//! transport boundary: initialize requests are handled locally with an emulated +//! response, while unrelated methods are not handled statelessly. + +use std::time::Duration; + +use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcRequest, +}; +use contextvm_sdk::transport::client::{ + NostrClientTransport, NostrClientTransportConfig, +}; +use contextvm_sdk::signer; +use tokio::time::timeout; + +async fn make_stateless_transport() -> ( + NostrClientTransport, + tokio::sync::mpsc::UnboundedReceiver, +) { + let server_keys = signer::generate(); + let client_keys = signer::generate(); + + let config = NostrClientTransportConfig { + relay_urls: Vec::new(), + server_pubkey: server_keys.public_key().to_hex(), + encryption_mode: EncryptionMode::Optional, + is_stateless: true, + timeout: Duration::from_secs(1), + }; + + let mut transport = NostrClientTransport::new(client_keys, config) + .await + .expect("transport should be constructed"); + let rx = transport + .take_message_receiver() + .expect("message receiver should be available once"); + + (transport, rx) +} + +#[tokio::test] +async fn create_emulated_response_returns_correct_request_id() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("test-id"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + transport + .send(&request) + .await + .expect("initialize should be emulated in stateless mode"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("should receive emulated response promptly") + .expect("channel should contain response"); + + match msg { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!("test-id")); + assert_eq!(resp.jsonrpc, "2.0"); + assert_eq!( + resp.result + .get("protocolVersion") + .and_then(serde_json::Value::as_str), + Some(mcp_protocol_version()) + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("name")) + .and_then(serde_json::Value::as_str), + Some("Emulated-Stateless-Server") + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("version")) + .and_then(serde_json::Value::as_str), + Some("1.0.0") + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("tools")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("prompts")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("subscribe")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + } + other => panic!("expected Response, got {other:?}"), + } + + let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; + assert!( + duplicate.is_err(), + "initialize request should emit exactly one emulated response" + ); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_true_for_initialize() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: None, + }); + + transport + .send(&request) + .await + .expect("initialize should be handled statelessly"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("initialize should produce local emulated response") + .expect("response should be delivered"); + + assert_eq!(msg.id(), Some(&serde_json::json!(1))); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_false_for_other_methods() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + + let _send_result = transport.send(&request).await; + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "non-initialize request should not create a local emulated response" + ); + +} From 963467e47653a1c3d8674af4f0fbc64f88c90d46 Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Wed, 8 Apr 2026 11:05:48 +0530 Subject: [PATCH 31/69] Polish stateless test wording --- tests/conformance_stateless_mode.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index bbddde5..5c7637a 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -1,8 +1,4 @@ -//! Conformance tests for stateless mode behavior in client transport. -//! -//! These mirror the TypeScript SDK StatelessModeHandler expectations at the -//! transport boundary: initialize requests are handled locally with an emulated -//! response, while unrelated methods are not handled statelessly. +//! Stateless-mode conformance tests for the client transport. use std::time::Duration; From 1bc6b0f0344a91f806281863b68dbfce616b429e Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 9 Apr 2026 00:37:03 +0530 Subject: [PATCH 32/69] style: apply cargo fmt formatting --- .github/workflows/rust.yml | 9 + src/encryption/mod.rs | 9 +- src/transport/client.rs | 8 +- tests/conformance_stateless_mode.rs | 345 ++++++++++++------------ tests/conformance_wire_format.rs | 12 +- tests/integration.rs | 390 ++++++++++++++-------------- 6 files changed, 387 insertions(+), 386 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index cc4327a..43b2061 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -29,5 +29,14 @@ jobs: - name: Test run: cargo test --all --all-features + - name: Format check + run: cargo fmt --all -- --check + + - name: Clippy + run: cargo clippy --all --all-features -- -D warnings + + - name: Doc check + run: cargo doc --no-deps --all-features + - name: Build run: cargo build --verbose \ No newline at end of file diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index a7a6d39..20a5362 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -279,15 +279,12 @@ mod tests { .unwrap(); // Step 2: tamper the pubkey (keep original, now-invalid, signature) - let mut forged_json: serde_json::Value = - serde_json::to_value(&inner_event).unwrap(); - forged_json["pubkey"] = - serde_json::Value::String(impersonated.public_key().to_hex()); + let mut forged_json: serde_json::Value = serde_json::to_value(&inner_event).unwrap(); + forged_json["pubkey"] = serde_json::Value::String(impersonated.public_key().to_hex()); let forged_str = serde_json::to_string(&forged_json).unwrap(); // Step 3: gift-wrap the forged payload - let (gift_wrap, _) = - create_simple_gift_wrap(&forged_str, &recipient.public_key()).await; + let (gift_wrap, _) = create_simple_gift_wrap(&forged_str, &recipient.public_key()).await; // Decrypt + parse both succeed — the forgery is syntactically valid let decrypted = decrypt_gift_wrap_single_layer(&recipient, &gift_wrap) diff --git a/src/transport/client.rs b/src/transport/client.rs index db30252..640f207 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -293,7 +293,9 @@ impl NostrClientTransport { match serde_json::from_str::(&decrypted_json) { Ok(inner) => { if let Err(e) = inner.verify() { - tracing::warn!("Inner event signature verification failed: {e}"); + tracing::warn!( + "Inner event signature verification failed: {e}" + ); continue; } let e_tag = serializers::get_tag_value(&inner.tags, "e"); @@ -348,9 +350,7 @@ impl NostrClientTransport { } // Parse MCP message - if let Some(mcp_msg) = - validation::validate_and_parse(&actual_event_content) - { + if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { // Clean up pending request if let Some(ref correlated_id) = e_tag { pending.write().await.remove(correlated_id.as_str()); diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 5c7637a..015e020 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -1,175 +1,170 @@ -//! Stateless-mode conformance tests for the client transport. - -use std::time::Duration; - -use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; -use contextvm_sdk::core::types::{ - EncryptionMode, JsonRpcMessage, JsonRpcRequest, -}; -use contextvm_sdk::transport::client::{ - NostrClientTransport, NostrClientTransportConfig, -}; -use contextvm_sdk::signer; -use tokio::time::timeout; - -async fn make_stateless_transport() -> ( - NostrClientTransport, - tokio::sync::mpsc::UnboundedReceiver, -) { - let server_keys = signer::generate(); - let client_keys = signer::generate(); - - let config = NostrClientTransportConfig { - relay_urls: Vec::new(), - server_pubkey: server_keys.public_key().to_hex(), - encryption_mode: EncryptionMode::Optional, - is_stateless: true, - timeout: Duration::from_secs(1), - }; - - let mut transport = NostrClientTransport::new(client_keys, config) - .await - .expect("transport should be constructed"); - let rx = transport - .take_message_receiver() - .expect("message receiver should be available once"); - - (transport, rx) -} - -#[tokio::test] -async fn create_emulated_response_returns_correct_request_id() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!("test-id"), - method: INITIALIZE_METHOD.to_string(), - params: Some(serde_json::json!({ - "protocolVersion": mcp_protocol_version(), - "capabilities": {}, - "clientInfo": { "name": "conformance-test", "version": "0.0.0" } - })), - }); - - transport - .send(&request) - .await - .expect("initialize should be emulated in stateless mode"); - - let msg = timeout(Duration::from_millis(200), rx.recv()) - .await - .expect("should receive emulated response promptly") - .expect("channel should contain response"); - - match msg { - JsonRpcMessage::Response(resp) => { - assert_eq!(resp.id, serde_json::json!("test-id")); - assert_eq!(resp.jsonrpc, "2.0"); - assert_eq!( - resp.result - .get("protocolVersion") - .and_then(serde_json::Value::as_str), - Some(mcp_protocol_version()) - ); - assert_eq!( - resp.result - .get("serverInfo") - .and_then(|v| v.get("name")) - .and_then(serde_json::Value::as_str), - Some("Emulated-Stateless-Server") - ); - assert_eq!( - resp.result - .get("serverInfo") - .and_then(|v| v.get("version")) - .and_then(serde_json::Value::as_str), - Some("1.0.0") - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("tools")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("prompts")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("resources")) - .and_then(|v| v.get("subscribe")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("resources")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - } - other => panic!("expected Response, got {other:?}"), - } - - let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; - assert!( - duplicate.is_err(), - "initialize request should emit exactly one emulated response" - ); -} - -#[tokio::test] -async fn should_handle_statelessly_returns_true_for_initialize() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: INITIALIZE_METHOD.to_string(), - params: None, - }); - - transport - .send(&request) - .await - .expect("initialize should be handled statelessly"); - - let msg = timeout(Duration::from_millis(200), rx.recv()) - .await - .expect("initialize should produce local emulated response") - .expect("response should be delivered"); - - assert_eq!(msg.id(), Some(&serde_json::json!(1))); -} - -#[tokio::test] -async fn should_handle_statelessly_returns_false_for_other_methods() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(2), - method: "tools/list".to_string(), - params: None, - }); - - let _send_result = transport.send(&request).await; - - let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; - assert!( - recv_result.is_err(), - "non-initialize request should not create a local emulated response" - ); - -} +//! Stateless-mode conformance tests for the client transport. + +use std::time::Duration; + +use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; +use contextvm_sdk::core::types::{EncryptionMode, JsonRpcMessage, JsonRpcRequest}; +use contextvm_sdk::signer; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use tokio::time::timeout; + +async fn make_stateless_transport() -> ( + NostrClientTransport, + tokio::sync::mpsc::UnboundedReceiver, +) { + let server_keys = signer::generate(); + let client_keys = signer::generate(); + + let config = NostrClientTransportConfig { + relay_urls: Vec::new(), + server_pubkey: server_keys.public_key().to_hex(), + encryption_mode: EncryptionMode::Optional, + is_stateless: true, + timeout: Duration::from_secs(1), + }; + + let mut transport = NostrClientTransport::new(client_keys, config) + .await + .expect("transport should be constructed"); + let rx = transport + .take_message_receiver() + .expect("message receiver should be available once"); + + (transport, rx) +} + +#[tokio::test] +async fn create_emulated_response_returns_correct_request_id() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("test-id"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + transport + .send(&request) + .await + .expect("initialize should be emulated in stateless mode"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("should receive emulated response promptly") + .expect("channel should contain response"); + + match msg { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!("test-id")); + assert_eq!(resp.jsonrpc, "2.0"); + assert_eq!( + resp.result + .get("protocolVersion") + .and_then(serde_json::Value::as_str), + Some(mcp_protocol_version()) + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("name")) + .and_then(serde_json::Value::as_str), + Some("Emulated-Stateless-Server") + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("version")) + .and_then(serde_json::Value::as_str), + Some("1.0.0") + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("tools")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("prompts")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("subscribe")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + } + other => panic!("expected Response, got {other:?}"), + } + + let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; + assert!( + duplicate.is_err(), + "initialize request should emit exactly one emulated response" + ); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_true_for_initialize() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: None, + }); + + transport + .send(&request) + .await + .expect("initialize should be handled statelessly"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("initialize should produce local emulated response") + .expect("response should be delivered"); + + assert_eq!(msg.id(), Some(&serde_json::json!(1))); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_false_for_other_methods() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + + let _send_result = transport.send(&request).await; + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "non-initialize request should not create a local emulated response" + ); +} diff --git a/tests/conformance_wire_format.rs b/tests/conformance_wire_format.rs index ee039fc..828ac38 100644 --- a/tests/conformance_wire_format.rs +++ b/tests/conformance_wire_format.rs @@ -95,10 +95,11 @@ fn ctxvm_initialize_response_has_kind_e_tag_and_result_protocol_version() { params: Some(serde_json::json!({})), }); let recipient_tags = BaseTransport::create_recipient_tags(&server_pk); - let request_event = serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) - .expect("request event for response correlation should serialize") - .sign_with_keys(&client_keys) - .expect("sign request event for correlation"); + let request_event = + serializers::mcp_to_nostr_event(&init_req, CTXVM_MESSAGES_KIND, recipient_tags) + .expect("request event for response correlation should serialize") + .sign_with_keys(&client_keys) + .expect("sign request event for correlation"); let init_resp = JsonRpcMessage::Response(JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -449,8 +450,7 @@ fn ctxvm_server_announcement_has_kind_and_required_tags() { assert!( event.tags.iter().any(|t| { let parts = t.clone().to_vec(); - parts.len() == 1 - && parts.first().map(|s| s.as_str()) == Some(tags::SUPPORT_ENCRYPTION) + parts.len() == 1 && parts.first().map(|s| s.as_str()) == Some(tags::SUPPORT_ENCRYPTION) }), "support_encryption must be present as a single-element tag" ); diff --git a/tests/integration.rs b/tests/integration.rs index 7507057..339f2a0 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,195 +1,195 @@ -//! Local RMCP integration test (in-process duplex I/O, no relay required). -//! Relay-dependent scenarios live in `examples/rmcp_integration_test.rs` -//! and run via the `integration.yml` workflow against a local relay container. - -#![cfg(feature = "rmcp")] - -use rmcp::{ - handler::server::router::tool::ToolRouter, handler::server::wrapper::Parameters, model::*, - schemars, tool, tool_handler, tool_router, ClientHandler, RoleServer, ServerHandler, - ServiceExt, service::RequestContext, -}; -use std::sync::Arc; -use tokio::sync::Mutex; - -// Minimal fixture: same tools as examples/rmcp_integration_test.rs - -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -struct EchoParams { - message: String, -} - -#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] -struct AddParams { - a: i64, - b: i64, -} - -#[derive(Clone)] -struct DemoServer { - echo_count: Arc>, - tool_router: ToolRouter, -} - -impl DemoServer { - fn new() -> Self { - Self { - echo_count: Arc::new(Mutex::new(0)), - tool_router: Self::tool_router(), - } - } -} - -#[tool_router] -impl DemoServer { - #[tool(description = "Echo a message back")] - async fn echo( - &self, - Parameters(EchoParams { message }): Parameters, - ) -> Result { - let mut n = self.echo_count.lock().await; - *n += 1; - Ok(CallToolResult::success(vec![Content::text(format!( - "Echo #{n}: {message}" - ))])) - } - - #[tool(description = "Add two integers")] - fn add( - &self, - Parameters(AddParams { a, b }): Parameters, - ) -> Result { - Ok(CallToolResult::success(vec![Content::text(format!( - "{a} + {b} = {}", - a + b - ))])) - } - - #[tool(description = "Return total echo calls")] - async fn get_echo_count(&self) -> Result { - let n = self.echo_count.lock().await; - Ok(CallToolResult::success(vec![Content::text(format!( - "Total echo calls: {n}" - ))])) - } -} - -#[tool_handler] -impl ServerHandler for DemoServer { - fn get_info(&self) -> ServerInfo { - ServerInfo { - protocol_version: ProtocolVersion::LATEST, - capabilities: ServerCapabilities::builder() - .enable_tools() - .enable_resources() - .build(), - server_info: Implementation { - name: "integration-test".to_string(), - title: None, - version: "0.1.0".to_string(), - description: None, - icons: None, - website_url: None, - }, - instructions: None, - } - } - - async fn list_resources( - &self, - _req: Option, - _ctx: RequestContext, - ) -> Result { - Ok(ListResourcesResult { - resources: vec![ - RawResource::new("demo://readme", "Demo README".to_string()).no_annotation(), - ], - next_cursor: None, - meta: None, - }) - } - - async fn read_resource( - &self, - req: ReadResourceRequestParams, - _ctx: RequestContext, - ) -> Result { - match req.uri.as_str() { - "demo://readme" => Ok(ReadResourceResult { - contents: vec![ResourceContents::text("Demo content.", req.uri)], - }), - other => Err(ErrorData::resource_not_found( - "not_found", - Some(serde_json::json!({ "uri": other })), - )), - } - } -} - -#[derive(Clone, Default)] -struct DemoClient; -impl ClientHandler for DemoClient {} - -fn first_text(result: &CallToolResult) -> String { - result - .content - .iter() - .find_map(|c| match &c.raw { - RawContent::Text(t) => Some(t.text.clone()), - _ => None, - }) - .unwrap_or_default() -} - -// ── Test ───────────────────────────────────────────────────────────────── - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_local_rmcp() { - let (server_io, client_io) = tokio::io::duplex(65536); - - let server_handle = tokio::spawn(async move { - DemoServer::new() - .serve(server_io) - .await - .expect("serve") - .waiting() - .await - .expect("server error"); - }); - - let client = DemoClient.serve(client_io).await.expect("client init"); - - let tools = client.list_all_tools().await.expect("list tools"); - assert_eq!(tools.len(), 3); - - let add = client - .call_tool(CallToolRequestParams { - name: "add".into(), - arguments: serde_json::from_value(serde_json::json!({ "a": 7, "b": 5 })).ok(), - meta: None, - task: None, - }) - .await - .expect("call add"); - assert!(first_text(&add).contains("12")); - - let resources = client.list_all_resources().await.expect("list resources"); - assert_eq!(resources.len(), 1); - - match client - .call_tool(CallToolRequestParams { - name: "no_such_tool".into(), - arguments: None, - meta: None, - task: None, - }) - .await - { - Err(_) => {} - Ok(r) if r.is_error.unwrap_or(false) => {} - Ok(_) => panic!("expected unknown tool to fail"), - } - - client.cancel().await.expect("cancel"); - server_handle.abort(); -} +//! Local RMCP integration test (in-process duplex I/O, no relay required). +//! Relay-dependent scenarios live in `examples/rmcp_integration_test.rs` +//! and run via the `integration.yml` workflow against a local relay container. + +#![cfg(feature = "rmcp")] + +use rmcp::{ + handler::server::router::tool::ToolRouter, handler::server::wrapper::Parameters, model::*, + schemars, service::RequestContext, tool, tool_handler, tool_router, ClientHandler, RoleServer, + ServerHandler, ServiceExt, +}; +use std::sync::Arc; +use tokio::sync::Mutex; + +// Minimal fixture: same tools as examples/rmcp_integration_test.rs + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct AddParams { + a: i64, + b: i64, +} + +#[derive(Clone)] +struct DemoServer { + echo_count: Arc>, + tool_router: ToolRouter, +} + +impl DemoServer { + fn new() -> Self { + Self { + echo_count: Arc::new(Mutex::new(0)), + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl DemoServer { + #[tool(description = "Echo a message back")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + let mut n = self.echo_count.lock().await; + *n += 1; + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo #{n}: {message}" + ))])) + } + + #[tool(description = "Add two integers")] + fn add( + &self, + Parameters(AddParams { a, b }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "{a} + {b} = {}", + a + b + ))])) + } + + #[tool(description = "Return total echo calls")] + async fn get_echo_count(&self) -> Result { + let n = self.echo_count.lock().await; + Ok(CallToolResult::success(vec![Content::text(format!( + "Total echo calls: {n}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for DemoServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + server_info: Implementation { + name: "integration-test".to_string(), + title: None, + version: "0.1.0".to_string(), + description: None, + icons: None, + website_url: None, + }, + instructions: None, + } + } + + async fn list_resources( + &self, + _req: Option, + _ctx: RequestContext, + ) -> Result { + Ok(ListResourcesResult { + resources: vec![ + RawResource::new("demo://readme", "Demo README".to_string()).no_annotation() + ], + next_cursor: None, + meta: None, + }) + } + + async fn read_resource( + &self, + req: ReadResourceRequestParams, + _ctx: RequestContext, + ) -> Result { + match req.uri.as_str() { + "demo://readme" => Ok(ReadResourceResult { + contents: vec![ResourceContents::text("Demo content.", req.uri)], + }), + other => Err(ErrorData::resource_not_found( + "not_found", + Some(serde_json::json!({ "uri": other })), + )), + } + } +} + +#[derive(Clone, Default)] +struct DemoClient; +impl ClientHandler for DemoClient {} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|c| match &c.raw { + RawContent::Text(t) => Some(t.text.clone()), + _ => None, + }) + .unwrap_or_default() +} + +// ── Test ───────────────────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_local_rmcp() { + let (server_io, client_io) = tokio::io::duplex(65536); + + let server_handle = tokio::spawn(async move { + DemoServer::new() + .serve(server_io) + .await + .expect("serve") + .waiting() + .await + .expect("server error"); + }); + + let client = DemoClient.serve(client_io).await.expect("client init"); + + let tools = client.list_all_tools().await.expect("list tools"); + assert_eq!(tools.len(), 3); + + let add = client + .call_tool(CallToolRequestParams { + name: "add".into(), + arguments: serde_json::from_value(serde_json::json!({ "a": 7, "b": 5 })).ok(), + meta: None, + task: None, + }) + .await + .expect("call add"); + assert!(first_text(&add).contains("12")); + + let resources = client.list_all_resources().await.expect("list resources"); + assert_eq!(resources.len(), 1); + + match client + .call_tool(CallToolRequestParams { + name: "no_such_tool".into(), + arguments: None, + meta: None, + task: None, + }) + .await + { + Err(_) => {} + Ok(r) if r.is_error.unwrap_or(false) => {} + Ok(_) => panic!("expected unknown tool to fail"), + } + + client.cancel().await.expect("cancel"); + server_handle.abort(); +} From be45cf6026d13d7d2f25d72c0b46ea30b9fcb6b8 Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 9 Apr 2026 00:46:52 +0530 Subject: [PATCH 33/69] resolve ci --- tests/conformance_stateless_mode.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 015e020..54b633b 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -21,6 +21,7 @@ async fn make_stateless_transport() -> ( encryption_mode: EncryptionMode::Optional, is_stateless: true, timeout: Duration::from_secs(1), + log_file_path: None, }; let mut transport = NostrClientTransport::new(client_keys, config) From bc80b8aaebc32210ee5534801d9a35c18ac04e5b Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 9 Apr 2026 01:29:08 +0530 Subject: [PATCH 34/69] test: add conformance tests for signer behavior --- tests/conformance_signer.rs | 75 ++++++ tests/conformance_stateless_mode.rs | 351 ++++++++++++++-------------- 2 files changed, 251 insertions(+), 175 deletions(-) create mode 100644 tests/conformance_signer.rs diff --git a/tests/conformance_signer.rs b/tests/conformance_signer.rs new file mode 100644 index 0000000..5b7c661 --- /dev/null +++ b/tests/conformance_signer.rs @@ -0,0 +1,75 @@ +//! Conformance tests for signer behavior (hex `from_sk`, `generate`, NIP-44, signing). +//! +//! Same layout as `conformance_wire_format.rs`; scenarios follow the TS SDK +//! `private-key-signer.test.ts` alongside `src/signer/mod.rs` / `src/encryption/mod.rs`. + +use contextvm_sdk::encryption::{decrypt_nip44, encrypt_nip44}; +use contextvm_sdk::signer::{self, Keys}; +use nostr_sdk::prelude::*; + +/// Secret `1`, x-only pubkey of secp256k1 `G`. +const FIXTURE_SK_HEX: &str = "0000000000000000000000000000000000000000000000000000000000000001"; +const FIXTURE_PK_HEX: &str = "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"; + +fn fixture_keys() -> Keys { + signer::from_sk(FIXTURE_SK_HEX).expect("fixture SK hex parses") +} + +// ── Key derivation ─────────────────────────────────────────────────────────── + +#[test] +fn signer_generates_keypair_from_secret_key() { + let keys = fixture_keys(); + assert_eq!(keys.public_key().to_hex(), FIXTURE_PK_HEX); +} + +// ── Random generation ──────────────────────────────────────────────────────── + +#[test] +fn signer_generates_random_keypair_when_no_secret_provided() { + let keys = signer::generate(); + assert_eq!(keys.public_key().to_hex().len(), 64); +} + +// ── NIP-44 ─────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn signer_nip44_encrypt_decrypt_roundtrip() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + let plaintext = "Hello Encryption!"; + + let ciphertext = encrypt_nip44(&sender_keys, &recipient_keys.public_key(), plaintext) + .await + .expect("nip44 encrypt"); + + assert_ne!(ciphertext, plaintext); + + let decrypted = decrypt_nip44(&recipient_keys, &sender_keys.public_key(), &ciphertext) + .await + .expect("nip44 decrypt"); + + assert_eq!(decrypted, plaintext); +} + +// ── Public key ─────────────────────────────────────────────────────────────── + +#[test] +fn signer_get_public_key_returns_correct_key() { + let keys = fixture_keys(); + let expected_pk = PublicKey::parse(FIXTURE_PK_HEX).expect("fixture PK hex parses"); + assert_eq!(keys.public_key(), expected_pk); +} + +// ── Signed events ──────────────────────────────────────────────────────────── + +#[test] +fn signer_signed_event_has_valid_signature() { + let keys = fixture_keys(); + let event = EventBuilder::new(Kind::TextNote, "Hello Nostr!") + .sign_with_keys(&keys) + .expect("sign text note"); + + assert_eq!(event.pubkey, keys.public_key()); + event.verify().expect("verify signed event"); +} diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 5c7637a..c8cc437 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -1,175 +1,176 @@ -//! Stateless-mode conformance tests for the client transport. - -use std::time::Duration; - -use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; -use contextvm_sdk::core::types::{ - EncryptionMode, JsonRpcMessage, JsonRpcRequest, -}; -use contextvm_sdk::transport::client::{ - NostrClientTransport, NostrClientTransportConfig, -}; -use contextvm_sdk::signer; -use tokio::time::timeout; - -async fn make_stateless_transport() -> ( - NostrClientTransport, - tokio::sync::mpsc::UnboundedReceiver, -) { - let server_keys = signer::generate(); - let client_keys = signer::generate(); - - let config = NostrClientTransportConfig { - relay_urls: Vec::new(), - server_pubkey: server_keys.public_key().to_hex(), - encryption_mode: EncryptionMode::Optional, - is_stateless: true, - timeout: Duration::from_secs(1), - }; - - let mut transport = NostrClientTransport::new(client_keys, config) - .await - .expect("transport should be constructed"); - let rx = transport - .take_message_receiver() - .expect("message receiver should be available once"); - - (transport, rx) -} - -#[tokio::test] -async fn create_emulated_response_returns_correct_request_id() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!("test-id"), - method: INITIALIZE_METHOD.to_string(), - params: Some(serde_json::json!({ - "protocolVersion": mcp_protocol_version(), - "capabilities": {}, - "clientInfo": { "name": "conformance-test", "version": "0.0.0" } - })), - }); - - transport - .send(&request) - .await - .expect("initialize should be emulated in stateless mode"); - - let msg = timeout(Duration::from_millis(200), rx.recv()) - .await - .expect("should receive emulated response promptly") - .expect("channel should contain response"); - - match msg { - JsonRpcMessage::Response(resp) => { - assert_eq!(resp.id, serde_json::json!("test-id")); - assert_eq!(resp.jsonrpc, "2.0"); - assert_eq!( - resp.result - .get("protocolVersion") - .and_then(serde_json::Value::as_str), - Some(mcp_protocol_version()) - ); - assert_eq!( - resp.result - .get("serverInfo") - .and_then(|v| v.get("name")) - .and_then(serde_json::Value::as_str), - Some("Emulated-Stateless-Server") - ); - assert_eq!( - resp.result - .get("serverInfo") - .and_then(|v| v.get("version")) - .and_then(serde_json::Value::as_str), - Some("1.0.0") - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("tools")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("prompts")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("resources")) - .and_then(|v| v.get("subscribe")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - assert_eq!( - resp.result - .get("capabilities") - .and_then(|v| v.get("resources")) - .and_then(|v| v.get("listChanged")) - .and_then(serde_json::Value::as_bool), - Some(true) - ); - } - other => panic!("expected Response, got {other:?}"), - } - - let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; - assert!( - duplicate.is_err(), - "initialize request should emit exactly one emulated response" - ); -} - -#[tokio::test] -async fn should_handle_statelessly_returns_true_for_initialize() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: INITIALIZE_METHOD.to_string(), - params: None, - }); - - transport - .send(&request) - .await - .expect("initialize should be handled statelessly"); - - let msg = timeout(Duration::from_millis(200), rx.recv()) - .await - .expect("initialize should produce local emulated response") - .expect("response should be delivered"); - - assert_eq!(msg.id(), Some(&serde_json::json!(1))); -} - -#[tokio::test] -async fn should_handle_statelessly_returns_false_for_other_methods() { - let (transport, mut rx) = make_stateless_transport().await; - - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(2), - method: "tools/list".to_string(), - params: None, - }); - - let _send_result = transport.send(&request).await; - - let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; - assert!( - recv_result.is_err(), - "non-initialize request should not create a local emulated response" - ); - -} +//! Stateless-mode conformance tests for the client transport. + +use std::time::Duration; + +use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcRequest, +}; +use contextvm_sdk::transport::client::{ + NostrClientTransport, NostrClientTransportConfig, +}; +use contextvm_sdk::signer; +use tokio::time::timeout; + +async fn make_stateless_transport() -> ( + NostrClientTransport, + tokio::sync::mpsc::UnboundedReceiver, +) { + let server_keys = signer::generate(); + let client_keys = signer::generate(); + + let config = NostrClientTransportConfig { + relay_urls: Vec::new(), + server_pubkey: server_keys.public_key().to_hex(), + encryption_mode: EncryptionMode::Optional, + is_stateless: true, + timeout: Duration::from_secs(1), + log_file_path: None, + }; + + let mut transport = NostrClientTransport::new(client_keys, config) + .await + .expect("transport should be constructed"); + let rx = transport + .take_message_receiver() + .expect("message receiver should be available once"); + + (transport, rx) +} + +#[tokio::test] +async fn create_emulated_response_returns_correct_request_id() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("test-id"), + method: INITIALIZE_METHOD.to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "conformance-test", "version": "0.0.0" } + })), + }); + + transport + .send(&request) + .await + .expect("initialize should be emulated in stateless mode"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("should receive emulated response promptly") + .expect("channel should contain response"); + + match msg { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id, serde_json::json!("test-id")); + assert_eq!(resp.jsonrpc, "2.0"); + assert_eq!( + resp.result + .get("protocolVersion") + .and_then(serde_json::Value::as_str), + Some(mcp_protocol_version()) + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("name")) + .and_then(serde_json::Value::as_str), + Some("Emulated-Stateless-Server") + ); + assert_eq!( + resp.result + .get("serverInfo") + .and_then(|v| v.get("version")) + .and_then(serde_json::Value::as_str), + Some("1.0.0") + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("tools")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("prompts")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("subscribe")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + assert_eq!( + resp.result + .get("capabilities") + .and_then(|v| v.get("resources")) + .and_then(|v| v.get("listChanged")) + .and_then(serde_json::Value::as_bool), + Some(true) + ); + } + other => panic!("expected Response, got {other:?}"), + } + + let duplicate = timeout(Duration::from_millis(100), rx.recv()).await; + assert!( + duplicate.is_err(), + "initialize request should emit exactly one emulated response" + ); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_true_for_initialize() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: INITIALIZE_METHOD.to_string(), + params: None, + }); + + transport + .send(&request) + .await + .expect("initialize should be handled statelessly"); + + let msg = timeout(Duration::from_millis(200), rx.recv()) + .await + .expect("initialize should produce local emulated response") + .expect("response should be delivered"); + + assert_eq!(msg.id(), Some(&serde_json::json!(1))); +} + +#[tokio::test] +async fn should_handle_statelessly_returns_false_for_other_methods() { + let (transport, mut rx) = make_stateless_transport().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + + let _send_result = transport.send(&request).await; + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "non-initialize request should not create a local emulated response" + ); + +} From 53bb7ed0ad80211eb0bb8c7e32e732f0e1d98c51 Mon Sep 17 00:00:00 2001 From: Anshuman Singh Date: Thu, 9 Apr 2026 22:32:39 +0530 Subject: [PATCH 35/69] fix: enforce inbound encryption policy in client event loop The client event_loop accepted the encryption_mode parameter but never checked it (_encryption_mode with leading underscore). This let plaintext events through when EncryptionMode::Required was set, and encrypted events through when EncryptionMode::Disabled was set. Changes: - Rename _encryption_mode -> encryption_mode so it is actually used. - Add is_gift_wrap_kind() and violates_encryption_policy() helpers. - Guard the event loop: drop events that violate the configured policy before any decrypt/parse work. - Add unit tests proving Required drops plaintext, Disabled drops encrypted, and Optional accepts both. --- src/transport/client.rs | 113 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 108 insertions(+), 5 deletions(-) diff --git a/src/transport/client.rs b/src/transport/client.rs index 640f207..b61704d 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -265,17 +265,32 @@ impl NostrClientTransport { pending: Arc>>, server_pubkey: PublicKey, tx: tokio::sync::mpsc::UnboundedSender, - _encryption_mode: EncryptionMode, + encryption_mode: EncryptionMode, ) { let mut notifications = client.notifications(); while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { + let is_gift_wrap = is_gift_wrap_kind(&event.kind); + + // Enforce mode before decrypt/parse. + if violates_encryption_policy(&event.kind, &encryption_mode) { + if is_gift_wrap { + tracing::warn!( + event_id = %event.id.to_hex(), + "Received encrypted response but encryption is disabled" + ); + } else { + tracing::warn!( + event_id = %event.id.to_hex(), + "Received unencrypted response but encryption is required" + ); + } + continue; + } + // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag) = if event.kind - == Kind::Custom(GIFT_WRAP_KIND) - || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) - { + let (actual_event_content, actual_pubkey, e_tag) = if is_gift_wrap { // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, @@ -362,6 +377,20 @@ impl NostrClientTransport { } } +#[inline] +fn is_gift_wrap_kind(kind: &Kind) -> bool { + *kind == Kind::Custom(GIFT_WRAP_KIND) || *kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) +} + +/// Returns `true` when the inbound event kind violates the configured encryption +/// policy and must be dropped before any further processing. +#[inline] +fn violates_encryption_policy(kind: &Kind, mode: &EncryptionMode) -> bool { + let is_gift_wrap = is_gift_wrap_kind(kind); + (is_gift_wrap && *mode == EncryptionMode::Disabled) + || (!is_gift_wrap && *mode == EncryptionMode::Required) +} + #[cfg(test)] mod tests { use super::*; @@ -437,4 +466,78 @@ mod tests { }); assert_eq!(init_notif.method(), Some("notifications/initialized")); } + + #[test] + fn test_gift_wrap_kind_detection() { + assert!(is_gift_wrap_kind(&Kind::Custom(GIFT_WRAP_KIND))); + assert!(is_gift_wrap_kind(&Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND))); + assert!(!is_gift_wrap_kind(&Kind::Custom(CTXVM_MESSAGES_KIND))); + } + + #[test] + fn test_required_mode_drops_plaintext() { + let plaintext_kind = Kind::Custom(CTXVM_MESSAGES_KIND); + assert!( + violates_encryption_policy(&plaintext_kind, &EncryptionMode::Required), + "Required mode must reject plaintext (non-gift-wrap) events" + ); + } + + #[test] + fn test_disabled_mode_drops_encrypted() { + assert!( + violates_encryption_policy(&Kind::Custom(GIFT_WRAP_KIND), &EncryptionMode::Disabled), + "Disabled mode must reject gift-wrap events" + ); + assert!( + violates_encryption_policy( + &Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND), + &EncryptionMode::Disabled + ), + "Disabled mode must reject ephemeral gift-wrap events" + ); + } + + #[test] + fn test_optional_mode_accepts_all() { + let plaintext = Kind::Custom(CTXVM_MESSAGES_KIND); + let gift_wrap = Kind::Custom(GIFT_WRAP_KIND); + let ephemeral = Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND); + assert!(!violates_encryption_policy( + &plaintext, + &EncryptionMode::Optional + )); + assert!(!violates_encryption_policy( + &gift_wrap, + &EncryptionMode::Optional + )); + assert!(!violates_encryption_policy( + &ephemeral, + &EncryptionMode::Optional + )); + } + + #[test] + fn test_required_mode_accepts_encrypted() { + assert!( + !violates_encryption_policy(&Kind::Custom(GIFT_WRAP_KIND), &EncryptionMode::Required), + "Required mode must accept gift-wrap events" + ); + assert!( + !violates_encryption_policy( + &Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND), + &EncryptionMode::Required + ), + "Required mode must accept ephemeral gift-wrap events" + ); + } + + #[test] + fn test_disabled_mode_accepts_plaintext() { + let plaintext = Kind::Custom(CTXVM_MESSAGES_KIND); + assert!( + !violates_encryption_policy(&plaintext, &EncryptionMode::Disabled), + "Disabled mode must accept plaintext events" + ); + } } From 392538bc7a06dbb39d9f840f93333515981a4372 Mon Sep 17 00:00:00 2001 From: Anurag <86455065+theAnuragMishra@users.noreply.github.com> Date: Mon, 13 Apr 2026 00:37:00 +0530 Subject: [PATCH 36/69] fix(transport): fix comment saying we should not use a since filter on gift-wrap subscriptions The comment says that we should not use a since filter on gift-wrap subscriptions which is not true. This commit fixes it. Also, this commit fixes the comment to include kind 21059 filter which is used right now but not mentioned in the comment. --- src/transport/base.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transport/base.rs b/src/transport/base.rs index 2b34544..88fda0c 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -55,19 +55,17 @@ impl BaseTransport { /// Subscribe to events targeting a pubkey (both regular and encrypted). /// - /// Uses two filters: one for ephemeral ContextVM messages (kind 25910) - /// with `since: now()`, and one for NIP-59 gift wraps (kind 1059) without - /// a `since` constraint. Gift wraps use randomized timestamps per NIP-59, - /// so a `since: now()` filter would reject most incoming encrypted messages. + /// Uses three filters: one for ephemeral ContextVM messages (kind 25910) + /// and two for NIP-59 gift wraps (kinds 1059 and 21059). pub async fn subscribe_for_pubkey(&self, pubkey: &PublicKey) -> Result<()> { let p_tag = pubkey.to_hex(); + let now = Timestamp::now(); let ephemeral_filter = Filter::new() .kind(Kind::Custom(CTXVM_MESSAGES_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) - .since(Timestamp::now()); + .since(now); - let now = Timestamp::now(); let gift_wrap_filter = Filter::new() .kind(Kind::Custom(GIFT_WRAP_KIND)) .custom_tag(SingleLetterTag::lowercase(Alphabet::P), p_tag.clone()) From d6fdb3d9be6a63329e0c92ab424ee0eeb36b6f83 Mon Sep 17 00:00:00 2001 From: Harsh Date: Fri, 10 Apr 2026 01:40:46 +0530 Subject: [PATCH 37/69] fix: implement gift-wrap deduplication using LRU cache --- Cargo.toml | 3 +++ src/transport/client.rs | 45 +++++++++++++++++++++++++++++++++++++++-- src/transport/server.rs | 37 ++++++++++++++++++++++++++++++++- 3 files changed, 82 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e4e03ba..95e675a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,9 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } +# LRU cache for gift-wrap (outer event id) deduplication +lru = "0.12" + [features] # Enable rmcp by default while keeping legacy APIs available. default = ["rmcp"] diff --git a/src/transport/client.rs b/src/transport/client.rs index b61704d..f7d9541 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -4,9 +4,11 @@ //! kind 25910 events, correlates responses via `e` tag. use std::collections::HashSet; -use std::sync::Arc; +use std::num::NonZeroUsize; +use std::sync::{Arc, Mutex}; use std::time::Duration; +use lru::LruCache; use nostr_sdk::prelude::*; use tokio::sync::RwLock; @@ -59,6 +61,10 @@ pub struct NostrClientTransport { server_pubkey: PublicKey, /// Pending request event IDs awaiting responses. pending_requests: Arc>>, + /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). + /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success + /// so failed decrypt/verify can be retried on redelivery. + seen_gift_wrap_ids: Arc>>, /// Channel for receiving processed MCP messages from the event loop. message_tx: tokio::sync::mpsc::UnboundedSender, message_rx: Option>, @@ -91,6 +97,9 @@ impl NostrClientTransport { error })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); tracing::info!( target: LOG_TARGET, @@ -108,6 +117,7 @@ impl NostrClientTransport { config, server_pubkey, pending_requests: Arc::new(RwLock::new(HashSet::new())), + seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), }) @@ -160,9 +170,18 @@ impl NostrClientTransport { let server_pubkey = self.server_pubkey; let tx = self.message_tx.clone(); let encryption_mode = self.config.encryption_mode; + let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); tokio::spawn(async move { - Self::event_loop(client, pending, server_pubkey, tx, encryption_mode).await; + Self::event_loop( + client, + pending, + server_pubkey, + tx, + encryption_mode, + seen_gift_wrap_ids, + ) + .await; }); tracing::info!( @@ -266,6 +285,7 @@ impl NostrClientTransport { server_pubkey: PublicKey, tx: tokio::sync::mpsc::UnboundedSender, encryption_mode: EncryptionMode, + seen_gift_wrap_ids: Arc>>, ) { let mut notifications = client.notifications(); @@ -291,6 +311,20 @@ impl NostrClientTransport { // Handle gift-wrapped events let (actual_event_content, actual_pubkey, e_tag) = if is_gift_wrap { + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + continue; + } + } // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, @@ -313,6 +347,13 @@ impl NostrClientTransport { ); continue; } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } let e_tag = serializers::get_tag_value(&inner.tags, "e"); (inner.content, inner.pubkey, e_tag) } diff --git a/src/transport/server.rs b/src/transport/server.rs index 69f9d3a..c7cc912 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -5,9 +5,11 @@ //! server announcements. use std::collections::HashMap; -use std::sync::Arc; +use std::num::NonZeroUsize; +use std::sync::{Arc, Mutex}; use std::time::Duration; +use lru::LruCache; use nostr_sdk::prelude::*; use tokio::sync::RwLock; @@ -69,6 +71,10 @@ pub struct NostrServerTransport { sessions: Arc>>, /// Reverse lookup: event_id → client_pubkey_hex event_to_client: Arc>>, + /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). + /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success + /// so failed decrypt/verify can be retried on redelivery. + seen_gift_wrap_ids: Arc>>, /// Channel for incoming MCP messages (consumed by the MCP server). message_tx: tokio::sync::mpsc::UnboundedSender, message_rx: Option>, @@ -104,6 +110,9 @@ impl NostrServerTransport { error })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); tracing::info!( target: LOG_TARGET, @@ -121,6 +130,7 @@ impl NostrServerTransport { config, sessions: Arc::new(RwLock::new(HashMap::new())), event_to_client: Arc::new(RwLock::new(HashMap::new())), + seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), }) @@ -175,6 +185,7 @@ impl NostrServerTransport { let allowed = self.config.allowed_public_keys.clone(); let excluded = self.config.excluded_capabilities.clone(); let encryption_mode = self.config.encryption_mode; + let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); tokio::spawn(async move { Self::event_loop( @@ -185,6 +196,7 @@ impl NostrServerTransport { allowed, excluded, encryption_mode, + seen_gift_wrap_ids, ) .await; }); @@ -585,6 +597,7 @@ impl NostrServerTransport { }) } + #[allow(clippy::too_many_arguments)] async fn event_loop( client: Arc, sessions: Arc>>, @@ -593,6 +606,7 @@ impl NostrServerTransport { allowed_pubkeys: Vec, excluded_capabilities: Vec, encryption_mode: EncryptionMode, + seen_gift_wrap_ids: Arc>>, ) { let mut notifications = client.notifications(); @@ -611,6 +625,20 @@ impl NostrServerTransport { ); continue; } + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + continue; + } + } // Single-layer NIP-44 decrypt (matches JS/TS SDK) let signer = match client.signer().await { Ok(s) => s, @@ -636,6 +664,13 @@ impl NostrServerTransport { ); continue; } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } ( inner.content, inner.pubkey.to_hex(), From b7fe8efb18c5a065229280293489b7f5fee5233c Mon Sep 17 00:00:00 2001 From: Harsh Date: Sun, 12 Apr 2026 04:24:07 +0530 Subject: [PATCH 38/69] refactor: extract RelayPoolTrait and implement MockRelayPool for testing --- src/relay/mock.rs | 306 ++++++++++++++++++++++++++++++++++++++++ src/relay/mod.rs | 70 +++++++++ src/transport/base.rs | 5 +- src/transport/client.rs | 29 ++-- src/transport/server.rs | 29 ++-- 5 files changed, 408 insertions(+), 31 deletions(-) create mode 100644 src/relay/mock.rs diff --git a/src/relay/mock.rs b/src/relay/mock.rs new file mode 100644 index 0000000..53039b2 --- /dev/null +++ b/src/relay/mock.rs @@ -0,0 +1,306 @@ +//! In-memory mock relay pool for network-free testing. +//! +//! Mirrors the design of the TypeScript SDK's `MockRelayHub`: +//! - `publish_event` stores the event and broadcasts it to all `notifications()` receivers. +//! - `subscribe` registers filters and immediately replays matching stored events through the +//! broadcast, so listeners that called `notifications()` before `subscribe()` see the replay. +//! - `connect` / `disconnect` are no-ops — no sockets are opened. +//! - Signing uses a freshly generated ephemeral `Keys`; `signer()` returns it wrapped in `Arc` +//! so encryption code can call it without any real relay connection. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use tokio::sync::Mutex; + +use nostr_sdk::prelude::*; + +use crate::core::error::{Error, Result}; +use crate::relay::RelayPoolTrait; + +// ── Internal state ──────────────────────────────────────────────────────────── + +struct MockRelayInner { + events: Vec, + /// Active subscriptions: id → filters registered by that subscription. + subscriptions: HashMap>, + next_sub_id: u32, +} + +impl MockRelayInner { + fn new() -> Self { + Self { + events: Vec::new(), + subscriptions: HashMap::new(), + next_sub_id: 0, + } + } +} + +// ── Public struct ───────────────────────────────────────────────────────────── + +/// In-memory relay pool for deterministic, network-free testing. +/// +/// Create one with [`MockRelayPool::new`] and pass it (wrapped in `Arc`) wherever +/// an `Arc` is expected. +pub struct MockRelayPool { + inner: Arc>, + /// Broadcast sender — every published event is sent here so that all + /// `notifications()` receivers see it. + notification_tx: tokio::sync::broadcast::Sender, + /// Ephemeral key used for signing in `publish` / `sign` / `signer`. + keys: Keys, +} + +impl MockRelayPool { + /// Create a new mock relay pool with a freshly generated ephemeral signing key. + pub fn new() -> Self { + let keys = Keys::generate(); + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + Self { + inner: Arc::new(Mutex::new(MockRelayInner::new())), + notification_tx: tx, + keys, + } + } + + /// The ephemeral public key used by this mock for signing. + pub fn mock_public_key(&self) -> PublicKey { + self.keys.public_key() + } + + /// Clone of all events published so far (useful for assertions in tests). + pub async fn stored_events(&self) -> Vec { + self.inner.lock().await.events.clone() + } +} + +impl Default for MockRelayPool { + fn default() -> Self { + Self::new() + } +} + +// ── RelayPoolTrait impl ─────────────────────────────────────────────────────── + +#[async_trait] +impl RelayPoolTrait for MockRelayPool { + /// No-op: the mock has no sockets to open. + async fn connect(&self, _relay_urls: &[String]) -> Result<()> { + Ok(()) + } + + /// No-op: the mock has no sockets to close. + async fn disconnect(&self) -> Result<()> { + Ok(()) + } + + /// Store the event and broadcast it to all current `notifications()` receivers. + async fn publish_event(&self, event: &Event) -> Result { + let event_id = event.id; + + { + let mut inner = self.inner.lock().await; + inner.events.push(event.clone()); + } + + // Always broadcast — consumers filter by kind/pubkey/tag themselves, + // which mirrors how nostr-sdk's real notification stream works. + let notification = make_notification(event.clone()); + // Ignore send errors: they just mean there are no active receivers yet. + let _ = self.notification_tx.send(notification); + + Ok(event_id) + } + + /// Sign `builder` with the ephemeral key, then call `publish_event`. + async fn publish(&self, builder: EventBuilder) -> Result { + let event = sign_with_keys(builder, &self.keys)?; + let id = event.id; + self.publish_event(&event).await?; + Ok(id) + } + + /// Sign `builder` with the ephemeral key and return the event without publishing. + async fn sign(&self, builder: EventBuilder) -> Result { + sign_with_keys(builder, &self.keys) + } + + /// Return the ephemeral key as a signer. + async fn signer(&self) -> Result> { + Ok(Arc::new(self.keys.clone()) as Arc) + } + + /// Return a new broadcast receiver. Each call gets an independent receiver + /// that sees all events published *after* this call, plus any replayed by + /// a subsequent `subscribe()`. + fn notifications(&self) -> tokio::sync::broadcast::Receiver { + self.notification_tx.subscribe() + } + + /// Return the ephemeral public key. + async fn public_key(&self) -> Result { + Ok(self.keys.public_key()) + } + + /// Register the filters and immediately replay any already-stored events that + /// match them through the broadcast channel, mirroring the behaviour of a + /// real relay that sends historical events before EOSE. + async fn subscribe(&self, filters: Vec) -> Result<()> { + let replay = { + let mut inner = self.inner.lock().await; + let sub_id = inner.next_sub_id; + inner.next_sub_id += 1; + + // Store filters first so the replay read comes from the stored value, + // ensuring the field is both written and read (no dead-code warning). + inner.subscriptions.insert(sub_id, filters); + + // Clone events so we can release the events borrow before borrowing subscriptions. + let events_snapshot = inner.events.clone(); + let stored = inner.subscriptions.get(&sub_id).expect("just inserted"); + events_snapshot + .into_iter() + .filter(|e| { + stored + .iter() + .any(|f| f.match_event(e, MatchEventOptions::default())) + }) + .collect::>() + }; + + for event in replay { + let _ = self.notification_tx.send(make_notification(event)); + } + + Ok(()) + } +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn sign_with_keys(builder: EventBuilder, keys: &Keys) -> Result { + builder + .sign_with_keys(keys) + .map_err(|e| Error::Transport(e.to_string())) +} + +fn make_notification(event: Event) -> RelayPoolNotification { + RelayPoolNotification::Event { + relay_url: RelayUrl::parse("wss://mock.relay").expect("hardcoded URL"), + subscription_id: SubscriptionId::generate(), + event: Box::new(event), + } +} + +// ── Unit tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn connect_and_disconnect_are_noops() { + let pool = MockRelayPool::new(); + assert!(pool.connect(&["wss://unused".to_string()]).await.is_ok()); + assert!(pool.disconnect().await.is_ok()); + } + + #[tokio::test] + async fn publish_event_stores_and_broadcasts() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + let keys = Keys::generate(); + let event = EventBuilder::new(Kind::TextNote, "hello") + .sign_with_keys(&keys) + .unwrap(); + + pool.publish_event(&event).await.unwrap(); + + assert_eq!(pool.stored_events().await.len(), 1); + let notif = rx.try_recv().unwrap(); + if let RelayPoolNotification::Event { event: e, .. } = notif { + assert_eq!(e.id, event.id); + } else { + panic!("expected Event notification"); + } + } + + #[tokio::test] + async fn publish_signs_and_stores() { + let pool = MockRelayPool::new(); + let builder = EventBuilder::new(Kind::TextNote, "signed"); + pool.publish(builder).await.unwrap(); + let stored = pool.stored_events().await; + assert_eq!(stored.len(), 1); + assert_eq!(stored[0].pubkey, pool.mock_public_key()); + } + + #[tokio::test] + async fn sign_does_not_publish() { + let pool = MockRelayPool::new(); + let builder = EventBuilder::new(Kind::TextNote, "unsigned"); + let event = pool.sign(builder).await.unwrap(); + assert_eq!(event.pubkey, pool.mock_public_key()); + assert!(pool.stored_events().await.is_empty()); + } + + #[tokio::test] + async fn signer_uses_same_key_as_publish() { + let pool = MockRelayPool::new(); + let signer = pool.signer().await.unwrap(); + let expected_pubkey = pool.mock_public_key(); + assert_eq!(signer.get_public_key().await.unwrap(), expected_pubkey); + } + + #[tokio::test] + async fn subscribe_replays_matching_stored_events() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + // Pre-publish two events + let keys = Keys::generate(); + let e1 = EventBuilder::new(Kind::TextNote, "one") + .sign_with_keys(&keys) + .unwrap(); + let e2 = EventBuilder::new(Kind::Custom(9999), "two") + .sign_with_keys(&keys) + .unwrap(); + pool.publish_event(&e1).await.unwrap(); + pool.publish_event(&e2).await.unwrap(); + + // Drain the two publish notifications + rx.try_recv().unwrap(); + rx.try_recv().unwrap(); + + // Subscribe for TextNote only — e1 should be replayed, e2 not + let filter = Filter::new().kind(Kind::TextNote); + pool.subscribe(vec![filter]).await.unwrap(); + + let replayed = rx.try_recv().unwrap(); + if let RelayPoolNotification::Event { event, .. } = replayed { + assert_eq!(event.id, e1.id); + } else { + panic!("expected replayed Event notification"); + } + // e2 should not be replayed + assert!(rx.try_recv().is_err()); + } + + #[tokio::test] + async fn notifications_receives_future_publishes() { + let pool = MockRelayPool::new(); + let mut rx = pool.notifications(); + + let keys = Keys::generate(); + let event = EventBuilder::new(Kind::TextNote, "future") + .sign_with_keys(&keys) + .unwrap(); + pool.publish_event(&event).await.unwrap(); + + let notif = rx.try_recv().unwrap(); + assert!(matches!(notif, RelayPoolNotification::Event { .. })); + } +} diff --git a/src/relay/mod.rs b/src/relay/mod.rs index dedc94f..198a5a7 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -2,10 +2,38 @@ //! //! Wraps nostr-sdk's Client for relay connection, event publishing, and subscription. +pub mod mock; +pub use mock::MockRelayPool; + +use async_trait::async_trait; + use crate::core::error::{Error, Result}; use nostr_sdk::prelude::*; use std::sync::Arc; +/// Trait abstracting relay pool operations, enabling dependency injection and testing. +#[async_trait] +pub trait RelayPoolTrait: Send + Sync { + /// Connect to the given relay URLs. + async fn connect(&self, relay_urls: &[String]) -> Result<()>; + /// Disconnect from all relays. + async fn disconnect(&self) -> Result<()>; + /// Publish a pre-built event to relays. + async fn publish_event(&self, event: &Event) -> Result; + /// Build, sign, and publish an event from a builder. + async fn publish(&self, builder: EventBuilder) -> Result; + /// Sign an event builder without publishing. + async fn sign(&self, builder: EventBuilder) -> Result; + /// Get the signer associated with this relay pool. + async fn signer(&self) -> Result>; + /// Get notifications receiver for event streaming. + fn notifications(&self) -> tokio::sync::broadcast::Receiver; + /// Get the public key of the signer. + async fn public_key(&self) -> Result; + /// Subscribe to events matching filters. + async fn subscribe(&self, filters: Vec) -> Result<()>; +} + /// Relay pool wrapper for managing Nostr relay connections. pub struct RelayPool { client: Arc, @@ -106,3 +134,45 @@ impl RelayPool { Ok(()) } } + +#[async_trait] +impl RelayPoolTrait for RelayPool { + async fn connect(&self, relay_urls: &[String]) -> Result<()> { + RelayPool::connect(self, relay_urls).await + } + + async fn disconnect(&self) -> Result<()> { + RelayPool::disconnect(self).await + } + + async fn publish_event(&self, event: &Event) -> Result { + RelayPool::publish_event(self, event).await + } + + async fn publish(&self, builder: EventBuilder) -> Result { + RelayPool::publish(self, builder).await + } + + async fn sign(&self, builder: EventBuilder) -> Result { + RelayPool::sign(self, builder).await + } + + async fn signer(&self) -> Result> { + self.client + .signer() + .await + .map_err(|e| Error::Other(e.to_string())) + } + + fn notifications(&self) -> tokio::sync::broadcast::Receiver { + RelayPool::notifications(self) + } + + async fn public_key(&self) -> Result { + RelayPool::public_key(self).await + } + + async fn subscribe(&self, filters: Vec) -> Result<()> { + RelayPool::subscribe(self, filters).await + } +} diff --git a/src/transport/base.rs b/src/transport/base.rs index 88fda0c..4f488a5 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -9,7 +9,7 @@ use crate::core::serializers; use crate::core::types::{EncryptionMode, JsonRpcMessage}; use crate::core::validation; use crate::encryption; -use crate::relay::RelayPool; +use crate::relay::RelayPoolTrait; const LOG_TARGET: &str = "contextvm_sdk::transport::base"; @@ -20,7 +20,7 @@ const LOG_TARGET: &str = "contextvm_sdk::transport::base"; /// and [`NostrServerTransport`](super::server::NostrServerTransport). pub struct BaseTransport { /// The relay pool for publishing and subscribing to Nostr events. - pub relay_pool: Arc, + pub relay_pool: Arc, /// The encryption policy for outgoing messages. pub encryption_mode: EncryptionMode, /// Whether the transport is currently connected to relays. @@ -125,7 +125,6 @@ impl BaseTransport { serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?; let signer = self .relay_pool - .client() .signer() .await .map_err(|e| Error::Encryption(e.to_string()))?; diff --git a/src/transport/client.rs b/src/transport/client.rs index f7d9541..23f2853 100644 --- a/src/transport/client.rs +++ b/src/transport/client.rs @@ -18,7 +18,7 @@ use crate::core::serializers; use crate::core::types::*; use crate::core::validation; use crate::encryption; -use crate::relay::RelayPool; +use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; use crate::util::tracing_setup; @@ -88,14 +88,15 @@ impl NostrClientTransport { Error::Other(format!("Invalid server pubkey: {error}")) })?; - let relay_pool = Arc::new(RelayPool::new(signer).await.map_err(|error| { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to initialize relay pool for client transport" - ); - error - })?); + let relay_pool: Arc = + Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for client transport" + ); + error + })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), @@ -165,7 +166,7 @@ impl NostrClientTransport { })?; // Spawn event loop - let client = self.base.relay_pool.client().clone(); + let relay_pool = Arc::clone(&self.base.relay_pool); let pending = self.pending_requests.clone(); let server_pubkey = self.server_pubkey; let tx = self.message_tx.clone(); @@ -174,7 +175,7 @@ impl NostrClientTransport { tokio::spawn(async move { Self::event_loop( - client, + relay_pool, pending, server_pubkey, tx, @@ -280,14 +281,14 @@ impl NostrClientTransport { } async fn event_loop( - client: Arc, + relay_pool: Arc, pending: Arc>>, server_pubkey: PublicKey, tx: tokio::sync::mpsc::UnboundedSender, encryption_mode: EncryptionMode, seen_gift_wrap_ids: Arc>>, ) { - let mut notifications = client.notifications(); + let mut notifications = relay_pool.notifications(); while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { @@ -326,7 +327,7 @@ impl NostrClientTransport { } } // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { + let signer = match relay_pool.signer().await { Ok(s) => s, Err(error) => { tracing::error!( diff --git a/src/transport/server.rs b/src/transport/server.rs index c7cc912..9a41c2c 100644 --- a/src/transport/server.rs +++ b/src/transport/server.rs @@ -18,7 +18,7 @@ use crate::core::error::{Error, Result}; use crate::core::types::*; use crate::core::validation; use crate::encryption; -use crate::relay::RelayPool; +use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; use crate::util::tracing_setup; @@ -101,14 +101,15 @@ impl NostrServerTransport { { tracing_setup::init_tracer(config.log_file_path.as_deref())?; - let relay_pool = Arc::new(RelayPool::new(signer).await.map_err(|error| { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to initialize relay pool for server transport" - ); - error - })?); + let relay_pool: Arc = + Arc::new(RelayPool::new(signer).await.map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to initialize relay pool for server transport" + ); + error + })?); let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), @@ -178,7 +179,7 @@ impl NostrServerTransport { })?; // Spawn event loop - let client = self.base.relay_pool.client().clone(); + let relay_pool = Arc::clone(&self.base.relay_pool); let sessions = self.sessions.clone(); let event_to_client = self.event_to_client.clone(); let tx = self.message_tx.clone(); @@ -189,7 +190,7 @@ impl NostrServerTransport { tokio::spawn(async move { Self::event_loop( - client, + relay_pool, sessions, event_to_client, tx, @@ -599,7 +600,7 @@ impl NostrServerTransport { #[allow(clippy::too_many_arguments)] async fn event_loop( - client: Arc, + relay_pool: Arc, sessions: Arc>>, event_to_client: Arc>>, tx: tokio::sync::mpsc::UnboundedSender, @@ -608,7 +609,7 @@ impl NostrServerTransport { encryption_mode: EncryptionMode, seen_gift_wrap_ids: Arc>>, ) { - let mut notifications = client.notifications(); + let mut notifications = relay_pool.notifications(); while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { @@ -640,7 +641,7 @@ impl NostrServerTransport { } } // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match client.signer().await { + let signer = match relay_pool.signer().await { Ok(s) => s, Err(error) => { tracing::error!( From 26dfe1371d032a45659ce0be71638e91adf8803e Mon Sep 17 00:00:00 2001 From: Harsh Date: Wed, 22 Apr 2026 06:53:59 +0530 Subject: [PATCH 39/69] test: add conformance tests for gift-wrap deduplication --- tests/conformance_dedup.rs | 113 +++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 tests/conformance_dedup.rs diff --git a/tests/conformance_dedup.rs b/tests/conformance_dedup.rs new file mode 100644 index 0000000..3f53050 --- /dev/null +++ b/tests/conformance_dedup.rs @@ -0,0 +1,113 @@ +//! Conformance tests for gift-wrap deduplication via LRU cache. +//! +//! Both the client and server transports use an `LruCache` to skip +//! duplicate outer gift-wrap event IDs. The dedup check happens *before* decrypt +//! and the insert happens only *after* successful decrypt + inner `verify()`. +//! These tests exercise the LRU cache logic in isolation — no async, no transport. + +use std::num::NonZeroUsize; + +use lru::LruCache; +use nostr_sdk::prelude::*; + +use contextvm_sdk::core::constants::DEFAULT_LRU_SIZE; + +/// Helper: build a cache with the same capacity used by both transports. +fn new_dedup_cache() -> LruCache { + LruCache::new(NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero")) +} + +fn event_id_from_byte(b: u8) -> EventId { + EventId::from_byte_array([b; 32]) +} + +// ── Gift-wrap kind 1059 dedup ───────────────────────────────────────────────── + +#[test] +fn client_dedup_skips_duplicate_outer_gift_wrap_id() { + let mut cache = new_dedup_cache(); + let outer_id = event_id_from_byte(0x01); + + // First delivery: not yet seen, decrypt succeeds, insert into cache. + assert!( + !cache.contains(&outer_id), + "first delivery must not be in cache yet" + ); + cache.put(outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&outer_id), + "second delivery of the same outer id must be rejected" + ); +} + +#[test] +fn client_dedup_ephemeral_gift_wrap_skips_duplicate() { + let mut cache = new_dedup_cache(); + let ephemeral_outer_id = event_id_from_byte(0xE1); + + // First delivery of an ephemeral gift-wrap (kind 21059). + assert!( + !cache.contains(&ephemeral_outer_id), + "first delivery must not be in cache yet" + ); + cache.put(ephemeral_outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&ephemeral_outer_id), + "second delivery of the same ephemeral outer id must be rejected" + ); +} + +// ── Server dedup ────────────────────────────────────────────────────────────── + +#[test] +fn server_dedup_ephemeral_gift_wrap_skips_duplicate() { + let mut cache = new_dedup_cache(); + let ephemeral_outer_id = event_id_from_byte(0xE2); + + // First delivery of an ephemeral gift-wrap (kind 21059). + assert!( + !cache.contains(&ephemeral_outer_id), + "first delivery must not be in cache yet" + ); + cache.put(ephemeral_outer_id, ()); + + // Second delivery: same outer id is already cached, skip before decrypt. + assert!( + cache.contains(&ephemeral_outer_id), + "second delivery of the same ephemeral outer id must be rejected" + ); +} + +#[test] +fn server_dedup_lru_evicts_oldest_when_capacity_reached() { + let capacity = 3; + let mut cache: LruCache = + LruCache::new(NonZeroUsize::new(capacity).expect("non-zero")); + + let id_0 = event_id_from_byte(0x00); + let id_1 = event_id_from_byte(0x01); + let id_2 = event_id_from_byte(0x02); + let id_3 = event_id_from_byte(0x03); + + cache.put(id_0, ()); + cache.put(id_1, ()); + cache.put(id_2, ()); + + // Cache is at capacity (3). Inserting a fourth must evict the oldest (id_0). + cache.put(id_3, ()); + + assert!( + !cache.contains(&id_0), + "oldest entry must be evicted when capacity is exceeded" + ); + assert!(cache.contains(&id_1), "second entry must still be present"); + assert!(cache.contains(&id_2), "third entry must still be present"); + assert!( + cache.contains(&id_3), + "newly inserted entry must be present" + ); +} From ce94af140d4ac19db5b7a0601a06afb0fcfe4c36 Mon Sep 17 00:00:00 2001 From: Harsh Date: Wed, 22 Apr 2026 08:30:33 +0530 Subject: [PATCH 40/69] refactor: extract ClientCorrelationStore and ServerEventRouteStore into dedicated modules --- src/lib.rs | 8 +- src/transport/client/correlation_store.rs | 74 +++++++++++++ src/transport/{client.rs => client/mod.rs} | 21 ++-- src/transport/mod.rs | 4 +- src/transport/server/correlation_store.rs | 114 +++++++++++++++++++++ src/transport/{server.rs => server/mod.rs} | 91 ++++++++-------- 6 files changed, 250 insertions(+), 62 deletions(-) create mode 100644 src/transport/client/correlation_store.rs rename src/transport/{client.rs => client/mod.rs} (97%) create mode 100644 src/transport/server/correlation_store.rs rename src/transport/{server.rs => server/mod.rs} (94%) diff --git a/src/lib.rs b/src/lib.rs index 2157224..80f5945 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,8 +57,12 @@ pub use core::types::{ }; pub use discovery::ServerAnnouncement; pub use relay::RelayPool; -pub use transport::client::{NostrClientTransport, NostrClientTransportConfig}; -pub use transport::server::{IncomingRequest, NostrServerTransport, NostrServerTransportConfig}; +pub use transport::client::{ + ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig, +}; +pub use transport::server::{ + IncomingRequest, NostrServerTransport, NostrServerTransportConfig, ServerEventRouteStore, +}; #[cfg(feature = "rmcp")] pub use rmcp; diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs new file mode 100644 index 0000000..e080fa5 --- /dev/null +++ b/src/transport/client/correlation_store.rs @@ -0,0 +1,74 @@ +//! Client-side correlation store for tracking pending request event IDs. + +use std::collections::HashSet; +use std::sync::Arc; + +use tokio::sync::RwLock; + +/// Tracks pending request event IDs awaiting responses on the client side. +#[derive(Clone)] +pub struct ClientCorrelationStore { + pending_requests: Arc>>, +} + +impl Default for ClientCorrelationStore { + fn default() -> Self { + Self::new() + } +} + +impl ClientCorrelationStore { + pub fn new() -> Self { + Self { + pending_requests: Arc::new(RwLock::new(HashSet::new())), + } + } + + pub async fn register(&self, event_id: String) { + self.pending_requests.write().await.insert(event_id); + } + + pub async fn contains(&self, event_id: &str) -> bool { + self.pending_requests.read().await.contains(event_id) + } + + pub async fn remove(&self, event_id: &str) { + self.pending_requests.write().await.remove(event_id); + } + + pub async fn clear(&self) { + self.pending_requests.write().await.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn remove_nonexistent_is_noop() { + let store = ClientCorrelationStore::new(); + store.remove("nonexistent").await; + assert!(!store.contains("nonexistent").await); + } + + #[tokio::test] + async fn contains_after_clear() { + let store = ClientCorrelationStore::new(); + store.register("e1".into()).await; + store.register("e2".into()).await; + assert!(store.contains("e1").await); + store.clear().await; + assert!(!store.contains("e1").await); + assert!(!store.contains("e2").await); + } + + #[tokio::test] + async fn register_and_remove_roundtrip() { + let store = ClientCorrelationStore::new(); + store.register("e1".into()).await; + assert!(store.contains("e1").await); + store.remove("e1").await; + assert!(!store.contains("e1").await); + } +} diff --git a/src/transport/client.rs b/src/transport/client/mod.rs similarity index 97% rename from src/transport/client.rs rename to src/transport/client/mod.rs index 23f2853..cdd3e5b 100644 --- a/src/transport/client.rs +++ b/src/transport/client/mod.rs @@ -3,14 +3,16 @@ //! Connects to a remote MCP server over Nostr. Sends JSON-RPC requests as //! kind 25910 events, correlates responses via `e` tag. -use std::collections::HashSet; +pub mod correlation_store; + +pub use correlation_store::ClientCorrelationStore; + use std::num::NonZeroUsize; use std::sync::{Arc, Mutex}; use std::time::Duration; use lru::LruCache; use nostr_sdk::prelude::*; -use tokio::sync::RwLock; use crate::core::constants::*; use crate::core::error::{Error, Result}; @@ -60,7 +62,7 @@ pub struct NostrClientTransport { config: NostrClientTransportConfig, server_pubkey: PublicKey, /// Pending request event IDs awaiting responses. - pending_requests: Arc>>, + pending_requests: ClientCorrelationStore, /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success /// so failed decrypt/verify can be retried on redelivery. @@ -117,7 +119,7 @@ impl NostrClientTransport { }, config, server_pubkey, - pending_requests: Arc::new(RwLock::new(HashSet::new())), + pending_requests: ClientCorrelationStore::new(), seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), @@ -238,10 +240,7 @@ impl NostrClientTransport { })?; if matches!(message, JsonRpcMessage::Request(_)) { - self.pending_requests - .write() - .await - .insert(event_id.to_hex()); + self.pending_requests.register(event_id.to_hex()).await; } tracing::debug!( @@ -282,7 +281,7 @@ impl NostrClientTransport { async fn event_loop( relay_pool: Arc, - pending: Arc>>, + pending: ClientCorrelationStore, server_pubkey: PublicKey, tx: tokio::sync::mpsc::UnboundedSender, encryption_mode: EncryptionMode, @@ -395,7 +394,7 @@ impl NostrClientTransport { // Correlate response if let Some(ref correlated_id) = e_tag { - let is_pending = pending.read().await.contains(correlated_id.as_str()); + let is_pending = pending.contains(correlated_id.as_str()).await; if !is_pending { tracing::warn!( target: LOG_TARGET, @@ -410,7 +409,7 @@ impl NostrClientTransport { if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { // Clean up pending request if let Some(ref correlated_id) = e_tag { - pending.write().await.remove(correlated_id.as_str()); + pending.remove(correlated_id.as_str()).await; } let _ = tx.send(mcp_msg); } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 13a4f8a..0a31f45 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -7,5 +7,5 @@ pub mod base; pub mod client; pub mod server; -pub use client::{NostrClientTransport, NostrClientTransportConfig}; -pub use server::{NostrServerTransport, NostrServerTransportConfig}; +pub use client::{ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig}; +pub use server::{NostrServerTransport, NostrServerTransportConfig, ServerEventRouteStore}; diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs new file mode 100644 index 0000000..659c9d5 --- /dev/null +++ b/src/transport/server/correlation_store.rs @@ -0,0 +1,114 @@ +//! Server-side event route store for mapping event IDs to client public keys. + +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::RwLock; + +/// Maps event IDs to client public keys for response routing on the server side. +#[derive(Clone)] +pub struct ServerEventRouteStore { + event_to_client: Arc>>, +} + +impl Default for ServerEventRouteStore { + fn default() -> Self { + Self::new() + } +} + +impl ServerEventRouteStore { + pub fn new() -> Self { + Self { + event_to_client: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn register(&self, event_id: String, client_pubkey: String) { + self.event_to_client + .write() + .await + .insert(event_id, client_pubkey); + } + + /// Returns the client public key for the given event ID without removing it. + pub async fn get(&self, event_id: &str) -> Option { + self.event_to_client.read().await.get(event_id).cloned() + } + + /// Removes and returns the client public key for the given event ID. + pub async fn pop(&self, event_id: &str) -> Option { + self.event_to_client.write().await.remove(event_id) + } + + /// Removes all routes for a given client public key. + pub async fn remove_for_client(&self, client_pubkey: &str) { + self.event_to_client + .write() + .await + .retain(|_, v| v != client_pubkey); + } + + pub async fn clear(&self) { + self.event_to_client.write().await.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn pop_on_empty_returns_none() { + let store = ServerEventRouteStore::new(); + assert!(store.pop("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn get_returns_without_removing() { + let store = ServerEventRouteStore::new(); + store.register("e1".into(), "pk1".into()).await; + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + } + + #[tokio::test] + async fn pop_removes_entry() { + let store = ServerEventRouteStore::new(); + store.register("e1".into(), "pk1".into()).await; + assert_eq!(store.pop("e1").await.as_deref(), Some("pk1")); + assert!(store.pop("e1").await.is_none()); + } + + #[tokio::test] + async fn remove_for_client_only_removes_matching() { + let store = ServerEventRouteStore::new(); + store.register("e1".into(), "pk1".into()).await; + store.register("e2".into(), "pk2".into()).await; + store.register("e3".into(), "pk1".into()).await; + + store.remove_for_client("pk1").await; + + assert!(store.get("e1").await.is_none()); + assert!(store.get("e3").await.is_none()); + assert_eq!(store.get("e2").await.as_deref(), Some("pk2")); + } + + #[tokio::test] + async fn remove_for_client_noop_when_no_match() { + let store = ServerEventRouteStore::new(); + store.register("e1".into(), "pk1".into()).await; + store.remove_for_client("pk_other").await; + assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); + } + + #[tokio::test] + async fn clear_empties_store() { + let store = ServerEventRouteStore::new(); + store.register("e1".into(), "pk1".into()).await; + store.register("e2".into(), "pk2".into()).await; + store.clear().await; + assert!(store.get("e1").await.is_none()); + assert!(store.get("e2").await.is_none()); + } +} diff --git a/src/transport/server.rs b/src/transport/server/mod.rs similarity index 94% rename from src/transport/server.rs rename to src/transport/server/mod.rs index 9a41c2c..88567d8 100644 --- a/src/transport/server.rs +++ b/src/transport/server/mod.rs @@ -4,6 +4,10 @@ //! sessions, handles request/response correlation, and optionally publishes //! server announcements. +pub mod correlation_store; + +pub use correlation_store::ServerEventRouteStore; + use std::collections::HashMap; use std::num::NonZeroUsize; use std::sync::{Arc, Mutex}; @@ -70,7 +74,7 @@ pub struct NostrServerTransport { /// Client sessions: client_pubkey_hex → ClientSession sessions: Arc>>, /// Reverse lookup: event_id → client_pubkey_hex - event_to_client: Arc>>, + event_routes: ServerEventRouteStore, /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success /// so failed decrypt/verify can be retried on redelivery. @@ -130,7 +134,7 @@ impl NostrServerTransport { }, config, sessions: Arc::new(RwLock::new(HashMap::new())), - event_to_client: Arc::new(RwLock::new(HashMap::new())), + event_routes: ServerEventRouteStore::new(), seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), @@ -181,7 +185,7 @@ impl NostrServerTransport { // Spawn event loop let relay_pool = Arc::clone(&self.base.relay_pool); let sessions = self.sessions.clone(); - let event_to_client = self.event_to_client.clone(); + let event_routes = self.event_routes.clone(); let tx = self.message_tx.clone(); let allowed = self.config.allowed_public_keys.clone(); let excluded = self.config.excluded_capabilities.clone(); @@ -192,7 +196,7 @@ impl NostrServerTransport { Self::event_loop( relay_pool, sessions, - event_to_client, + event_routes, tx, allowed, excluded, @@ -204,7 +208,7 @@ impl NostrServerTransport { // Spawn session cleanup let sessions_cleanup = self.sessions.clone(); - let event_to_client_cleanup = self.event_to_client.clone(); + let event_routes_cleanup = self.event_routes.clone(); let cleanup_interval = self.config.cleanup_interval; let session_timeout = self.config.session_timeout; @@ -214,7 +218,7 @@ impl NostrServerTransport { interval.tick().await; let cleaned = Self::cleanup_sessions( &sessions_cleanup, - &event_to_client_cleanup, + &event_routes_cleanup, session_timeout, ) .await; @@ -242,25 +246,20 @@ impl NostrServerTransport { pub async fn close(&mut self) -> Result<()> { self.base.disconnect().await?; self.sessions.write().await.clear(); - self.event_to_client.write().await.clear(); + self.event_routes.clear().await; Ok(()) } /// Send a response back to the client that sent the original request. pub async fn send_response(&self, event_id: &str, mut response: JsonRpcMessage) -> Result<()> { - let event_to_client = self.event_to_client.read().await; - let client_pubkey_hex = event_to_client - .get(event_id) - .ok_or_else(|| { - tracing::error!( - target: LOG_TARGET, - event_id = %event_id, - "No client found for response correlation" - ); - Error::Other(format!("No client found for event {event_id}")) - })? - .clone(); - drop(event_to_client); + let client_pubkey_hex = self.event_routes.get(event_id).await.ok_or_else(|| { + tracing::error!( + target: LOG_TARGET, + event_id = %event_id, + "No client found for response correlation" + ); + Error::Other(format!("No client found for event {event_id}")) + })?; let sessions = self.sessions.read().await; let session = sessions.get(&client_pubkey_hex).ok_or_else(|| { @@ -326,7 +325,9 @@ impl NostrServerTransport { error })?; - // Clean up + // Clean up only after successful send + self.event_routes.pop(event_id).await; + let mut sessions = self.sessions.write().await; if let Some(session) = sessions.get_mut(&client_pubkey_hex) { // Clean up progress token @@ -337,8 +338,6 @@ impl NostrServerTransport { } drop(sessions); - self.event_to_client.write().await.remove(event_id); - tracing::debug!( target: LOG_TARGET, client_pubkey = %client_pubkey_hex, @@ -602,7 +601,7 @@ impl NostrServerTransport { async fn event_loop( relay_pool: Arc, sessions: Arc>>, - event_to_client: Arc>>, + event_routes: ServerEventRouteStore, tx: tokio::sync::mpsc::UnboundedSender, allowed_pubkeys: Vec, excluded_capabilities: Vec, @@ -768,10 +767,9 @@ impl NostrServerTransport { session .pending_requests .insert(event_id.clone(), original_id); - event_to_client - .write() - .await - .insert(event_id.clone(), sender_pubkey.clone()); + event_routes + .register(event_id.clone(), sender_pubkey.clone()) + .await; // Track progress token if let Some(token) = req @@ -812,22 +810,17 @@ impl NostrServerTransport { async fn cleanup_sessions( sessions: &RwLock>, - event_to_client: &RwLock>, + event_routes: &ServerEventRouteStore, timeout: Duration, ) -> usize { let mut sessions_w = sessions.write().await; - let mut event_map = event_to_client.write().await; let mut cleaned = 0; + let mut stale_event_ids = Vec::new(); sessions_w.retain(|pubkey, session| { if session.last_activity.elapsed() > timeout { - // Clean up reverse mappings - for event_id in session.pending_requests.keys() { - event_map.remove(event_id); - } - for event_id in session.event_to_progress_token.keys() { - event_map.remove(event_id); - } + stale_event_ids.extend(session.pending_requests.keys().cloned()); + stale_event_ids.extend(session.event_to_progress_token.keys().cloned()); tracing::debug!( target: LOG_TARGET, client_pubkey = %pubkey, @@ -839,6 +832,11 @@ impl NostrServerTransport { true } }); + drop(sessions_w); + + for event_id in &stale_event_ids { + event_routes.pop(event_id).await; + } cleaned } @@ -872,7 +870,7 @@ mod tests { #[tokio::test] async fn test_cleanup_sessions_removes_expired() { let sessions = Arc::new(RwLock::new(HashMap::new())); - let event_to_client = Arc::new(RwLock::new(HashMap::new())); + let event_routes = ServerEventRouteStore::new(); // Insert a session with an old activity time let mut session = ClientSession::new(false); @@ -883,15 +881,14 @@ mod tests { .write() .await .insert("pubkey1".to_string(), session); - event_to_client - .write() - .await - .insert("evt1".to_string(), "pubkey1".to_string()); + event_routes + .register("evt1".to_string(), "pubkey1".to_string()) + .await; // With a long timeout, nothing should be cleaned let cleaned = NostrServerTransport::cleanup_sessions( &sessions, - &event_to_client, + &event_routes, Duration::from_secs(300), ) .await; @@ -902,26 +899,26 @@ mod tests { thread::sleep(Duration::from_millis(5)); let cleaned = NostrServerTransport::cleanup_sessions( &sessions, - &event_to_client, + &event_routes, Duration::from_millis(1), ) .await; assert_eq!(cleaned, 1); assert!(sessions.read().await.is_empty()); - assert!(event_to_client.read().await.is_empty()); + assert!(event_routes.pop("evt1").await.is_none()); } #[tokio::test] async fn test_cleanup_preserves_active_sessions() { let sessions = Arc::new(RwLock::new(HashMap::new())); - let event_to_client = Arc::new(RwLock::new(HashMap::new())); + let event_routes = ServerEventRouteStore::new(); let session = ClientSession::new(false); sessions.write().await.insert("active".to_string(), session); let cleaned = NostrServerTransport::cleanup_sessions( &sessions, - &event_to_client, + &event_routes, Duration::from_secs(300), ) .await; From 6bdc58684e2a1e97af8b7cbb1bd25f6692598cfb Mon Sep 17 00:00:00 2001 From: Kushagra Date: Thu, 23 Apr 2026 13:35:15 +0530 Subject: [PATCH 41/69] feat(cep-19): add GiftWrapMode type and ephemeral gift-wrap encryption Introduces core CEP-19 types and encryption support: - GiftWrapMode enum (Optional/Ephemeral/Persistent) with policy helpers allows_kind() and supports_ephemeral() - gift_wrap_single_layer_with_kind() for creating 1059 or 21059 wraps - Re-export GiftWrapMode from crate root CEP-19: https://github.com/ContextVM/ceps/pull/19 --- src/core/types.rs | 80 ++++++++++++++++++++++++++++++++ src/encryption/mod.rs | 103 ++++++++++++++++++++++++++++++++++++------ src/lib.rs | 5 +- 3 files changed, 171 insertions(+), 17 deletions(-) diff --git a/src/core/types.rs b/src/core/types.rs index a811269..4ef3dd7 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::time::Instant; +use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; + // ── Encryption mode ───────────────────────────────────────────────── /// Encryption mode for transport communication. @@ -22,6 +24,39 @@ pub enum EncryptionMode { Disabled, } +// Gift-wrap mode (CEP-19) + +// Gift-wrap policy for encrypted transport communication (CEP-19). +// Controls whether encrypted messages use persistent gift wraps (kind `1059`), +// ephemeral gift wraps (kind `21059`), or adapt based on peer support. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum GiftWrapMode { + /// Prefer persistent gift wraps until ephemeral support is explicitly chosen or learned. + #[default] + Optional, + /// Force the ephemeral gift-wrap kind (`21059`) for encrypted messages. + Ephemeral, + /// Force the persistent gift-wrap kind (`1059`) for encrypted messages. + Persistent, +} + +impl GiftWrapMode { + /// Returns whether this mode accepts the given encrypted outer event kind. + pub fn allows_kind(self, kind: u16) -> bool { + match self { + Self::Optional => kind == GIFT_WRAP_KIND || kind == EPHEMERAL_GIFT_WRAP_KIND, + Self::Ephemeral => kind == EPHEMERAL_GIFT_WRAP_KIND, + Self::Persistent => kind == GIFT_WRAP_KIND, + } + } + + /// Returns whether this mode supports sending and advertising ephemeral gift wraps. + pub fn supports_ephemeral(self) -> bool { + !matches!(self, Self::Persistent) + } +} + // ── Server info ───────────────────────────────────────────────────── /// Server information for announcements (kind 11316). @@ -226,6 +261,7 @@ pub struct CapabilityExclusion { #[cfg(test)] mod tests { use super::*; + use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; use serde_json::json; use std::thread; use std::time::Duration; @@ -257,6 +293,50 @@ mod tests { assert_eq!(parsed, mode); } + #[test] + fn test_gift_wrap_mode_serde_roundtrip_optional() { + let mode = GiftWrapMode::Optional; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"optional\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_serde_roundtrip_ephemeral() { + let mode = GiftWrapMode::Ephemeral; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"ephemeral\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_serde_roundtrip_persistent() { + let mode = GiftWrapMode::Persistent; + let s = serde_json::to_string(&mode).unwrap(); + assert_eq!(s, "\"persistent\""); + let parsed: GiftWrapMode = serde_json::from_str(&s).unwrap(); + assert_eq!(parsed, mode); + } + + #[test] + fn test_gift_wrap_mode_policy_helpers() { + // Optional accepts both kinds + assert!(GiftWrapMode::Optional.allows_kind(GIFT_WRAP_KIND)); + assert!(GiftWrapMode::Optional.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + // Ephemeral only accepts 21059 + assert!(GiftWrapMode::Ephemeral.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + assert!(!GiftWrapMode::Ephemeral.allows_kind(GIFT_WRAP_KIND)); + // Persistent only accepts 1059 + assert!(GiftWrapMode::Persistent.allows_kind(GIFT_WRAP_KIND)); + assert!(!GiftWrapMode::Persistent.allows_kind(EPHEMERAL_GIFT_WRAP_KIND)); + // supports_ephemeral check + assert!(GiftWrapMode::Optional.supports_ephemeral()); + assert!(GiftWrapMode::Ephemeral.supports_ephemeral()); + assert!(!GiftWrapMode::Persistent.supports_ephemeral()); + } + fn assert_json_rpc_roundtrip(msg: &JsonRpcMessage) { let wire = serde_json::to_string(msg).unwrap(); let parsed: JsonRpcMessage = serde_json::from_str(&wire).unwrap(); diff --git a/src/encryption/mod.rs b/src/encryption/mod.rs index 20a5362..6a6d4bc 100644 --- a/src/encryption/mod.rs +++ b/src/encryption/mod.rs @@ -3,7 +3,7 @@ //! Provides NIP-44 encryption/decryption and NIP-59 gift wrapping. //! The actual gift wrapping is done via nostr-sdk's Client for full NIP-59 compliance. -use crate::core::constants::GIFT_WRAP_KIND; +use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; use crate::core::error::{Error, Result}; use nostr_sdk::prelude::*; @@ -37,12 +37,8 @@ where .map_err(|e| Error::Decryption(e.to_string())) } -/// Decrypt a single-layer NIP-44 gift wrap (kind 1059). -/// -/// This matches the ContextVM JS/TS SDK's encryption scheme: -/// - The gift wrap event has NIP-44 encrypted content (single layer) -/// - Decrypt using recipient's key + event's pubkey (ephemeral sender) -/// - Returns the decrypted plaintext content string +// Decrypt a single-layer NIP-44 gift wrap (kind 1059). + pub async fn decrypt_gift_wrap_single_layer(signer: &T, event: &Event) -> Result where T: NostrSigner, @@ -51,13 +47,8 @@ where decrypt_nip44(signer, &sender_pubkey, &event.content).await } -/// Create a single-layer NIP-44 gift wrap (kind 1059). -/// -/// Matches the ContextVM JS/TS SDK's `encryptMessage`: -/// 1. Generate ephemeral keypair -/// 2. NIP-44 encrypt plaintext using ephemeral_secret + recipient_pubkey -/// 3. Build kind 1059 event with `p` tag pointing to recipient -/// 4. Sign with ephemeral key +// Create a single-layer NIP-44 gift wrap (kind 1059). + pub async fn gift_wrap_single_layer( _signer: &T, recipient: &PublicKey, @@ -78,6 +69,37 @@ where .map_err(|e| Error::Encryption(e.to_string())) } +/// Create a single-layer NIP-44 gift wrap using the provided outer event kind. +/// +/// Only ContextVM's supported persistent (`1059`) and ephemeral (`21059`) gift-wrap +/// kinds are accepted here. +pub async fn gift_wrap_single_layer_with_kind( + _signer: &T, + recipient: &PublicKey, + plaintext: &str, + gift_wrap_kind: u16, +) -> Result +where + T: NostrSigner, +{ + if gift_wrap_kind != GIFT_WRAP_KIND && gift_wrap_kind != EPHEMERAL_GIFT_WRAP_KIND { + return Err(Error::Encryption(format!( + "Unsupported gift-wrap kind for single-layer encryption: {gift_wrap_kind}" + ))); + } + + let ephemeral = Keys::generate(); + + let encrypted = encrypt_nip44(&ephemeral, recipient, plaintext).await?; + + let builder = + EventBuilder::new(Kind::Custom(gift_wrap_kind), encrypted).tag(Tag::public_key(*recipient)); + + builder + .sign_with_keys(&ephemeral) + .map_err(|e| Error::Encryption(e.to_string())) +} + // Legacy NIP-59 functions kept for reference but deprecated. /// Decrypt a full NIP-59 gift-wrapped event using the Client. @@ -111,7 +133,7 @@ pub async fn gift_wrap( #[cfg(test)] mod tests { - use crate::core::constants::GIFT_WRAP_KIND; + use crate::core::constants::{EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND}; use super::*; @@ -299,4 +321,55 @@ mod tests { "forged inner event must fail signature verification" ); } + + #[tokio::test] + async fn test_ephemeral_gift_wrap_roundtrip_single_layer() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + + let mcp_content = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let inner_event = EventBuilder::new(Kind::Custom(25910), mcp_content) + .tag(Tag::public_key(recipient_keys.public_key())) + .sign_with_keys(&sender_keys) + .unwrap(); + let inner_json = serde_json::to_string(&inner_event).unwrap(); + + let gift_wrap_event = gift_wrap_single_layer_with_kind( + &sender_keys, + &recipient_keys.public_key(), + &inner_json, + EPHEMERAL_GIFT_WRAP_KIND, + ) + .await + .unwrap(); + + assert_eq!(gift_wrap_event.kind, Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)); + + let decrypted = decrypt_gift_wrap_single_layer(&recipient_keys, &gift_wrap_event) + .await + .unwrap(); + let parsed: Event = serde_json::from_str(&decrypted).unwrap(); + assert_eq!(parsed.pubkey, sender_keys.public_key()); + assert_eq!(parsed.content, mcp_content); + } + + #[tokio::test] + async fn test_invalid_gift_wrap_kind_rejected() { + let sender_keys = Keys::generate(); + let recipient_keys = Keys::generate(); + + let error = gift_wrap_single_layer_with_kind( + &sender_keys, + &recipient_keys.public_key(), + "test", + 4242, + ) + .await + .unwrap_err(); + + assert!( + error.to_string().contains("Unsupported gift-wrap kind"), + "unexpected error: {error}" + ); + } } diff --git a/src/lib.rs b/src/lib.rs index 2157224..f53c957 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,8 +52,9 @@ mod util; // Re-export commonly used types pub use core::error::{Error, Result}; pub use core::types::{ - CapabilityExclusion, ClientSession, EncryptionMode, JsonRpcError, JsonRpcErrorResponse, - JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ServerInfo, + CapabilityExclusion, ClientSession, EncryptionMode, GiftWrapMode, JsonRpcError, + JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + ServerInfo, }; pub use discovery::ServerAnnouncement; pub use relay::RelayPool; From eeb4e366c358c7ce4e9786312abdb3ba22ce8ff4 Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 23 Apr 2026 04:41:14 +0530 Subject: [PATCH 42/69] test: add Phase 3 integration tests using MockRelayPool --- src/lib.rs | 3 +- src/relay/mock.rs | 30 ++ src/transport/client/mod.rs | 44 +++ src/transport/server/mod.rs | 34 ++ tests/transport_integration.rs | 603 +++++++++++++++++++++++++++++++++ 5 files changed, 713 insertions(+), 1 deletion(-) create mode 100644 tests/transport_integration.rs diff --git a/src/lib.rs b/src/lib.rs index 80f5945..4172347 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -56,7 +56,8 @@ pub use core::types::{ JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ServerInfo, }; pub use discovery::ServerAnnouncement; -pub use relay::RelayPool; +pub use relay::mock::MockRelayPool; +pub use relay::{RelayPool, RelayPoolTrait}; pub use transport::client::{ ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig, }; diff --git a/src/relay/mock.rs b/src/relay/mock.rs index 53039b2..52e52bc 100644 --- a/src/relay/mock.rs +++ b/src/relay/mock.rs @@ -70,6 +70,36 @@ impl MockRelayPool { self.keys.public_key() } + /// Like [`new`](Self::new) but with caller-provided signing keys. + pub fn with_keys(keys: Keys) -> Self { + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + Self { + inner: Arc::new(Mutex::new(MockRelayInner::new())), + notification_tx: tx, + keys, + } + } + + /// Create a pair of linked mock relay pools with different signing keys. + /// + /// Both pools share the same event store and notification channel; events + /// published by one are visible to the other's `notifications()` receivers. + pub fn create_pair() -> (Self, Self) { + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + let inner = Arc::new(Mutex::new(MockRelayInner::new())); + let a = Self { + inner: Arc::clone(&inner), + notification_tx: tx.clone(), + keys: Keys::generate(), + }; + let b = Self { + inner, + notification_tx: tx, + keys: Keys::generate(), + }; + (a, b) + } + /// Clone of all events published so far (useful for assertions in tests). pub async fn stored_events(&self) -> Vec { self.inner.lock().await.events.clone() diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index cdd3e5b..955e8fc 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -126,6 +126,50 @@ impl NostrClientTransport { }) } + /// Like [`new`](Self::new) but accepts an existing relay pool. + pub async fn with_relay_pool( + config: NostrClientTransportConfig, + relay_pool: Arc, + ) -> Result { + tracing_setup::init_tracer(config.log_file_path.as_deref())?; + + let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %config.server_pubkey, + "Invalid server pubkey" + ); + Error::Other(format!("Invalid server pubkey: {error}")) + })?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + stateless = config.is_stateless, + encryption_mode = ?config.encryption_mode, + "Created client transport (with_relay_pool)" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + config, + server_pubkey, + pending_requests: ClientCorrelationStore::new(), + seen_gift_wrap_ids, + message_tx: tx, + message_rx: Some(rx), + }) + } + /// Connect and start listening for responses. pub async fn start(&mut self) -> Result<()> { self.base diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 88567d8..7781557 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -141,6 +141,40 @@ impl NostrServerTransport { }) } + /// Like [`new`](Self::new) but accepts an existing relay pool. + pub async fn with_relay_pool( + config: NostrServerTransportConfig, + relay_pool: Arc, + ) -> Result { + tracing_setup::init_tracer(config.log_file_path.as_deref())?; + + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( + NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), + ))); + + tracing::info!( + target: LOG_TARGET, + relay_count = config.relay_urls.len(), + announced = config.is_announced_server, + encryption_mode = ?config.encryption_mode, + "Created server transport (with_relay_pool)" + ); + Ok(Self { + base: BaseTransport { + relay_pool, + encryption_mode: config.encryption_mode, + is_connected: false, + }, + config, + sessions: Arc::new(RwLock::new(HashMap::new())), + event_routes: ServerEventRouteStore::new(), + seen_gift_wrap_ids, + message_tx: tx, + message_rx: Some(rx), + }) + } + /// Start listening for incoming requests. pub async fn start(&mut self) -> Result<()> { self.base diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs new file mode 100644 index 0000000..c577b79 --- /dev/null +++ b/tests/transport_integration.rs @@ -0,0 +1,603 @@ +//! Integration tests — transport-level flows using MockRelayPool. +//! +//! Each test wires client and/or server transports to an in-memory mock relay +//! network so that the full event-loop logic (subscription, publish, routing, +//! encryption-mode enforcement, and authorization) is exercised without +//! connecting to real relays. + +use std::sync::Arc; +use std::time::Duration; + +use contextvm_sdk::core::constants::{ + mcp_protocol_version, GIFT_WRAP_KIND, SERVER_ANNOUNCEMENT_KIND, +}; +use contextvm_sdk::core::types::EncryptionMode; +use contextvm_sdk::relay::mock::MockRelayPool; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use contextvm_sdk::{ + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, RelayPoolTrait, + ServerInfo, +}; +use nostr_sdk::prelude::*; + +fn as_pool(pool: MockRelayPool) -> Arc { + Arc::new(pool) +} + +/// Let spawned event loops call `notifications()` before we publish anything. +/// Without this, broadcast messages can be lost on slow CI runners. +async fn let_event_loops_start() { + tokio::time::sleep(Duration::from_millis(10)).await; +} + +// ── 1. Full initialization handshake ──────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn full_initialization_handshake() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends initialize request. + let init_request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "test-client", "version": "0.0.0" } + })), + }); + client + .send(&init_request) + .await + .expect("client send initialize"); + + // Server should receive the initialize request. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive init request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("initialize"), + "server must receive initialize request" + ); + + // Server sends initialize response. + let init_response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test-server", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_response) + .await + .expect("server send response"); + + // Client should receive the initialize response. + let response = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive init response") + .expect("client channel closed"); + + assert!(response.is_response(), "client must receive a response"); + assert_eq!(response.id(), Some(&serde_json::json!(1))); +} + +// ── 2. Server announcement publishing ─────────────────────────────────────── + +#[tokio::test] +async fn server_announcement_publishing() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + is_announced_server: true, + server_info: Some(ServerInfo { + name: Some("Phase3-Test-Server".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)); + + assert!( + announcement.is_some(), + "kind {} event must be published after announce()", + SERVER_ANNOUNCEMENT_KIND + ); + + let ann = announcement.unwrap(); + let content: serde_json::Value = + serde_json::from_str(&ann.content).expect("announcement content must be JSON"); + assert_eq!( + content["name"], "Phase3-Test-Server", + "announcement content must include server name" + ); +} + +// ── 3. Encryption mode Optional accepts plaintext ─────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_mode_optional_accepts_plaintext() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server uses Optional — should accept both encrypted and plaintext. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client uses Disabled — sends plaintext kind 25910. + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("plain-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + // Server must receive and process the plaintext message. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive plaintext request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "Optional-mode server must accept plaintext kind 25910" + ); + assert!( + !incoming.is_encrypted, + "plaintext request must not be marked as encrypted" + ); +} + +// ── 4. Auth allowlist blocks disallowed pubkey ────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_allowlist_blocks_disallowed_pubkey() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server allows only `allowed_keys` — client_keys is NOT allowed. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![allowed_keys.public_key().to_hex()], + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a non-initialize request (those are always allowed). + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // The server should NOT forward the request (pubkey is disallowed). + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "disallowed pubkey request must not reach the server handler" + ); +} + +// ── 5. Encryption mode Required drops plaintext ───────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_mode_required_drops_plaintext() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server requires encryption — plaintext must be dropped. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client sends plaintext (Disabled mode). + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("drop-me"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + // Server must NOT receive the plaintext message. + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "Required-mode server must drop plaintext kind 25910 events" + ); +} + +// ── 6. Encrypted gift-wrap roundtrip ──────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encrypted_gift_wrap_roundtrip() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends encrypted request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + // Verify the published event is a gift-wrap (kind 1059). + let events = server_pool.stored_events().await; + assert!( + events + .iter() + .any(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)), + "client must publish a kind 1059 gift-wrap event" + ); + + // Server should decrypt and receive the request. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to decrypt gift-wrap request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!(incoming.is_encrypted, "message must be marked encrypted"); + + // Server sends an encrypted response back. + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-1"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("server send encrypted response"); + + // Client should decrypt and receive the response. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to decrypt gift-wrap response") + .expect("client channel closed"); + + assert!(client_msg.is_response()); + assert_eq!(client_msg.id(), Some(&serde_json::json!("enc-1"))); +} + +// ── 7. Gift-wrap dedup skips duplicate delivery ───────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn gift_wrap_dedup_skips_duplicate_delivery() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends a gift-wrapped request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("dedup-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server receives the first delivery. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for first delivery") + .expect("server channel closed"); + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!(incoming.is_encrypted); + + // Re-deliver the same gift-wrap event (simulates relay redelivery). + let events = server_pool.stored_events().await; + let gift_wrap = events + .iter() + .find(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .expect("gift-wrap event must exist") + .clone(); + server_pool + .publish_event(&gift_wrap) + .await + .expect("re-inject duplicate"); + + // Server must NOT process the duplicate. + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "duplicate gift-wrap (same outer event id) must be skipped" + ); +} + +// ── 8. Correlated notification has e tag ───────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn correlated_notification_has_e_tag() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends a tools/list request. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("notif-corr"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server receives the request and captures the event_id. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + assert_eq!(incoming.message.method(), Some("tools/list")); + let request_event_id = incoming.event_id.clone(); + + // Server sends a correlated notifications/progress notification. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ + "progressToken": "tok-1", + "progress": 50, + "total": 100 + })), + }); + server + .send_notification( + &incoming.client_pubkey, + ¬ification, + Some(&request_event_id), + ) + .await + .expect("send correlated notification"); + + // Client should receive the notification. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive notification") + .expect("client channel closed"); + + assert!(client_msg.is_notification()); + assert_eq!(client_msg.method(), Some("notifications/progress")); + + // The published notification event must carry an e tag referencing the request. + let events = server_pool.stored_events().await; + let notif_event = events + .iter() + .find(|e| e.pubkey == server_pubkey && e.content.contains("notifications/progress")) + .expect("notification event must be in stored events"); + + let e_tag = contextvm_sdk::core::serializers::get_tag_value(¬if_event.tags, "e"); + assert_eq!( + e_tag.as_deref(), + Some(request_event_id.as_str()), + "notification event must have e tag referencing the original request event id" + ); +} From 3dfc682242fd0459df0e89132c92344d02cd084d Mon Sep 17 00:00:00 2001 From: Harsh Date: Fri, 24 Apr 2026 06:25:23 +0530 Subject: [PATCH 43/69] refactor: enrich correlation stores, add SessionStore, and port Phase 1/2 conformance tests --- src/lib.rs | 3 +- src/transport/client/correlation_store.rs | 47 +- src/transport/client/mod.rs | 6 +- src/transport/server/correlation_store.rs | 229 ++++++- src/transport/server/mod.rs | 73 ++- src/transport/server/session_store.rs | 204 ++++++ tests/conformance_stores.rs | 725 ++++++++++++++++++++++ 7 files changed, 1209 insertions(+), 78 deletions(-) create mode 100644 src/transport/server/session_store.rs create mode 100644 tests/conformance_stores.rs diff --git a/src/lib.rs b/src/lib.rs index 5671c22..828439f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -63,7 +63,8 @@ pub use transport::client::{ ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig, }; pub use transport::server::{ - IncomingRequest, NostrServerTransport, NostrServerTransportConfig, ServerEventRouteStore, + IncomingRequest, NostrServerTransport, NostrServerTransportConfig, RouteEntry, + ServerEventRouteStore, SessionSnapshot, SessionStore, }; #[cfg(feature = "rmcp")] diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs index e080fa5..140f664 100644 --- a/src/transport/client/correlation_store.rs +++ b/src/transport/client/correlation_store.rs @@ -1,14 +1,14 @@ //! Client-side correlation store for tracking pending request event IDs. -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -/// Tracks pending request event IDs awaiting responses on the client side. +/// Tracks pending request event IDs and their original request IDs on the client side. #[derive(Clone)] pub struct ClientCorrelationStore { - pending_requests: Arc>>, + pending_requests: Arc>>, } impl Default for ClientCorrelationStore { @@ -20,20 +20,39 @@ impl Default for ClientCorrelationStore { impl ClientCorrelationStore { pub fn new() -> Self { Self { - pending_requests: Arc::new(RwLock::new(HashSet::new())), + pending_requests: Arc::new(RwLock::new(HashMap::new())), } } - pub async fn register(&self, event_id: String) { - self.pending_requests.write().await.insert(event_id); + /// Register a pending request with its original JSON-RPC request ID. + pub async fn register(&self, event_id: String, original_id: serde_json::Value) { + self.pending_requests + .write() + .await + .insert(event_id, original_id); } pub async fn contains(&self, event_id: &str) -> bool { - self.pending_requests.read().await.contains(event_id) + self.pending_requests.read().await.contains_key(event_id) } - pub async fn remove(&self, event_id: &str) { - self.pending_requests.write().await.remove(event_id); + /// Remove a pending request. Returns `true` if the key existed. + pub async fn remove(&self, event_id: &str) -> bool { + self.pending_requests + .write() + .await + .remove(event_id) + .is_some() + } + + /// Retrieve the original request ID for a given event ID without removing it. + pub async fn get_original_id(&self, event_id: &str) -> Option { + self.pending_requests.read().await.get(event_id).cloned() + } + + /// Number of pending requests currently tracked. + pub async fn count(&self) -> usize { + self.pending_requests.read().await.len() } pub async fn clear(&self) { @@ -48,15 +67,15 @@ mod tests { #[tokio::test] async fn remove_nonexistent_is_noop() { let store = ClientCorrelationStore::new(); - store.remove("nonexistent").await; + assert!(!store.remove("nonexistent").await); assert!(!store.contains("nonexistent").await); } #[tokio::test] async fn contains_after_clear() { let store = ClientCorrelationStore::new(); - store.register("e1".into()).await; - store.register("e2".into()).await; + store.register("e1".into(), serde_json::Value::Null).await; + store.register("e2".into(), serde_json::Value::Null).await; assert!(store.contains("e1").await); store.clear().await; assert!(!store.contains("e1").await); @@ -66,9 +85,9 @@ mod tests { #[tokio::test] async fn register_and_remove_roundtrip() { let store = ClientCorrelationStore::new(); - store.register("e1".into()).await; + store.register("e1".into(), serde_json::Value::Null).await; assert!(store.contains("e1").await); - store.remove("e1").await; + assert!(store.remove("e1").await); assert!(!store.contains("e1").await); } } diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 955e8fc..d7f2c2b 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -283,8 +283,10 @@ impl NostrClientTransport { error })?; - if matches!(message, JsonRpcMessage::Request(_)) { - self.pending_requests.register(event_id.to_hex()).await; + if let JsonRpcMessage::Request(ref req) = message { + self.pending_requests + .register(event_id.to_hex(), req.id.clone()) + .await; } tracing::debug!( diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index 659c9d5..1879d80 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -1,14 +1,65 @@ -//! Server-side event route store for mapping event IDs to client public keys. +//! Server-side event route store for mapping event IDs to client routes. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::RwLock; -/// Maps event IDs to client public keys for response routing on the server side. +/// A route entry for an in-flight request. +#[derive(Debug, Clone)] +pub struct RouteEntry { + /// The client's public key that originated this request. + pub client_pubkey: String, + /// The original JSON-RPC request ID (before replacement with event ID). + pub original_request_id: serde_json::Value, + /// Optional progress token for this request. + pub progress_token: Option, +} + +/// Internal state behind the lock. +struct Inner { + /// Primary index: event_id → route entry. + routes: HashMap, + /// Secondary index: progress_token → event_id. + progress_token_to_event: HashMap, + /// Secondary index: client_pubkey → set of event_ids. + client_event_ids: HashMap>, +} + +impl Inner { + fn new() -> Self { + Self { + routes: HashMap::new(), + progress_token_to_event: HashMap::new(), + client_event_ids: HashMap::new(), + } + } + + /// Remove a single route and clean up all secondary indexes. + fn remove_route(&mut self, event_id: &str) -> Option { + let route = self.routes.remove(event_id)?; + + // Clean up progress token index. + if let Some(ref token) = route.progress_token { + self.progress_token_to_event.remove(token); + } + + // Clean up client index. + if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) { + set.remove(event_id); + if set.is_empty() { + self.client_event_ids.remove(&route.client_pubkey); + } + } + + Some(route) + } +} + +/// Maps event IDs to full route entries for response routing on the server side. #[derive(Clone)] pub struct ServerEventRouteStore { - event_to_client: Arc>>, + inner: Arc>, } impl Default for ServerEventRouteStore { @@ -20,43 +71,140 @@ impl Default for ServerEventRouteStore { impl ServerEventRouteStore { pub fn new() -> Self { Self { - event_to_client: Arc::new(RwLock::new(HashMap::new())), + inner: Arc::new(RwLock::new(Inner::new())), } } - pub async fn register(&self, event_id: String, client_pubkey: String) { - self.event_to_client - .write() - .await - .insert(event_id, client_pubkey); + /// Register a route for an incoming request. + pub async fn register( + &self, + event_id: String, + client_pubkey: String, + original_request_id: serde_json::Value, + progress_token: Option, + ) { + let mut inner = self.inner.write().await; + + // Update client index. + inner + .client_event_ids + .entry(client_pubkey.clone()) + .or_default() + .insert(event_id.clone()); + + // Update progress token index. + if let Some(ref token) = progress_token { + inner + .progress_token_to_event + .insert(token.clone(), event_id.clone()); + } + + inner.routes.insert( + event_id, + RouteEntry { + client_pubkey, + original_request_id, + progress_token, + }, + ); } /// Returns the client public key for the given event ID without removing it. pub async fn get(&self, event_id: &str) -> Option { - self.event_to_client.read().await.get(event_id).cloned() + self.inner + .read() + .await + .routes + .get(event_id) + .map(|r| r.client_pubkey.clone()) + } + + /// Returns the full route entry for the given event ID without removing it. + pub async fn get_route(&self, event_id: &str) -> Option { + self.inner.read().await.routes.get(event_id).cloned() + } + + /// Removes and returns the full route entry for the given event ID. + pub async fn pop(&self, event_id: &str) -> Option { + self.inner.write().await.remove_route(event_id) + } + + /// Removes all routes for a given client public key. Returns the count removed. + pub async fn remove_for_client(&self, client_pubkey: &str) -> usize { + let mut inner = self.inner.write().await; + + let event_ids = match inner.client_event_ids.remove(client_pubkey) { + Some(ids) => ids, + None => return 0, + }; + + let count = event_ids.len(); + for event_id in &event_ids { + if let Some(route) = inner.routes.remove(event_id.as_str()) { + if let Some(ref token) = route.progress_token { + inner.progress_token_to_event.remove(token); + } + } + } + count } - /// Removes and returns the client public key for the given event ID. - pub async fn pop(&self, event_id: &str) -> Option { - self.event_to_client.write().await.remove(event_id) + /// Check whether a route exists for the given event ID. + pub async fn has_event_route(&self, event_id: &str) -> bool { + self.inner.read().await.routes.contains_key(event_id) } - /// Removes all routes for a given client public key. - pub async fn remove_for_client(&self, client_pubkey: &str) { - self.event_to_client - .write() + /// Check whether the given client has any active routes. + pub async fn has_active_routes_for_client(&self, client_pubkey: &str) -> bool { + self.inner + .read() .await - .retain(|_, v| v != client_pubkey); + .client_event_ids + .get(client_pubkey) + .is_some_and(|set| !set.is_empty()) + } + + /// Look up the event ID associated with a progress token. + pub async fn get_event_id_by_progress_token(&self, token: &str) -> Option { + self.inner + .read() + .await + .progress_token_to_event + .get(token) + .cloned() + } + + /// Check whether a progress token mapping exists. + pub async fn has_progress_token(&self, token: &str) -> bool { + self.inner + .read() + .await + .progress_token_to_event + .contains_key(token) + } + + /// Number of event routes currently tracked. + pub async fn event_route_count(&self) -> usize { + self.inner.read().await.routes.len() + } + + /// Number of progress token mappings currently tracked. + pub async fn progress_token_count(&self) -> usize { + self.inner.read().await.progress_token_to_event.len() } pub async fn clear(&self) { - self.event_to_client.write().await.clear(); + let mut inner = self.inner.write().await; + inner.routes.clear(); + inner.progress_token_to_event.clear(); + inner.client_event_ids.clear(); } } #[cfg(test)] mod tests { use super::*; + use serde_json::json; #[tokio::test] async fn pop_on_empty_returns_none() { @@ -67,7 +215,9 @@ mod tests { #[tokio::test] async fn get_returns_without_removing() { let store = ServerEventRouteStore::new(); - store.register("e1".into(), "pk1".into()).await; + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); } @@ -75,19 +225,29 @@ mod tests { #[tokio::test] async fn pop_removes_entry() { let store = ServerEventRouteStore::new(); - store.register("e1".into(), "pk1".into()).await; - assert_eq!(store.pop("e1").await.as_deref(), Some("pk1")); + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + let route = store.pop("e1").await.unwrap(); + assert_eq!(route.client_pubkey, "pk1"); assert!(store.pop("e1").await.is_none()); } #[tokio::test] async fn remove_for_client_only_removes_matching() { let store = ServerEventRouteStore::new(); - store.register("e1".into(), "pk1".into()).await; - store.register("e2".into(), "pk2".into()).await; - store.register("e3".into(), "pk1".into()).await; + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + store + .register("e2".into(), "pk2".into(), json!("r2"), None) + .await; + store + .register("e3".into(), "pk1".into(), json!("r3"), None) + .await; - store.remove_for_client("pk1").await; + let removed = store.remove_for_client("pk1").await; + assert_eq!(removed, 2); assert!(store.get("e1").await.is_none()); assert!(store.get("e3").await.is_none()); @@ -97,16 +257,23 @@ mod tests { #[tokio::test] async fn remove_for_client_noop_when_no_match() { let store = ServerEventRouteStore::new(); - store.register("e1".into(), "pk1".into()).await; - store.remove_for_client("pk_other").await; + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + let removed = store.remove_for_client("pk_other").await; + assert_eq!(removed, 0); assert_eq!(store.get("e1").await.as_deref(), Some("pk1")); } #[tokio::test] async fn clear_empties_store() { let store = ServerEventRouteStore::new(); - store.register("e1".into(), "pk1".into()).await; - store.register("e2".into(), "pk2".into()).await; + store + .register("e1".into(), "pk1".into(), json!("r1"), None) + .await; + store + .register("e2".into(), "pk2".into(), json!("r2"), None) + .await; store.clear().await; assert!(store.get("e1").await.is_none()); assert!(store.get("e2").await.is_none()); diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 7781557..6105520 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -5,17 +5,17 @@ //! server announcements. pub mod correlation_store; +pub mod session_store; -pub use correlation_store::ServerEventRouteStore; +pub use correlation_store::{RouteEntry, ServerEventRouteStore}; +pub use session_store::{SessionSnapshot, SessionStore}; -use std::collections::HashMap; use std::num::NonZeroUsize; use std::sync::{Arc, Mutex}; use std::time::Duration; use lru::LruCache; use nostr_sdk::prelude::*; -use tokio::sync::RwLock; use crate::core::constants::*; use crate::core::error::{Error, Result}; @@ -71,9 +71,9 @@ impl Default for NostrServerTransportConfig { pub struct NostrServerTransport { base: BaseTransport, config: NostrServerTransportConfig, - /// Client sessions: client_pubkey_hex → ClientSession - sessions: Arc>>, - /// Reverse lookup: event_id → client_pubkey_hex + /// Client sessions. + sessions: SessionStore, + /// Reverse lookup: event_id → client route. event_routes: ServerEventRouteStore, /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success @@ -133,7 +133,7 @@ impl NostrServerTransport { is_connected: false, }, config, - sessions: Arc::new(RwLock::new(HashMap::new())), + sessions: SessionStore::new(), event_routes: ServerEventRouteStore::new(), seen_gift_wrap_ids, message_tx: tx, @@ -167,7 +167,7 @@ impl NostrServerTransport { is_connected: false, }, config, - sessions: Arc::new(RwLock::new(HashMap::new())), + sessions: SessionStore::new(), event_routes: ServerEventRouteStore::new(), seen_gift_wrap_ids, message_tx: tx, @@ -279,7 +279,7 @@ impl NostrServerTransport { /// Close the transport. pub async fn close(&mut self) -> Result<()> { self.base.disconnect().await?; - self.sessions.write().await.clear(); + self.sessions.clear().await; self.event_routes.clear().await; Ok(()) } @@ -634,7 +634,7 @@ impl NostrServerTransport { #[allow(clippy::too_many_arguments)] async fn event_loop( relay_pool: Arc, - sessions: Arc>>, + sessions: SessionStore, event_routes: ServerEventRouteStore, tx: tokio::sync::mpsc::UnboundedSender, allowed_pubkeys: Vec, @@ -798,28 +798,37 @@ impl NostrServerTransport { // Track request for correlation if let JsonRpcMessage::Request(ref req) = mcp_msg { let original_id = req.id.clone(); - session - .pending_requests - .insert(event_id.clone(), original_id); - event_routes - .register(event_id.clone(), sender_pubkey.clone()) - .await; - // Track progress token - if let Some(token) = req + // Extract progress token from _meta if present. + let progress_token = req .params .as_ref() .and_then(|p| p.get("_meta")) .and_then(|m| m.get("progressToken")) .and_then(|t| t.as_str()) - { + .map(String::from); + + // Duplicate into session fields (kept for backward compat). + session + .pending_requests + .insert(event_id.clone(), original_id.clone()); + if let Some(ref token) = progress_token { session .pending_requests - .insert(token.to_string(), serde_json::json!(event_id)); + .insert(token.clone(), serde_json::json!(event_id)); session .event_to_progress_token - .insert(event_id.clone(), token.to_string()); + .insert(event_id.clone(), token.clone()); } + + event_routes + .register( + event_id.clone(), + sender_pubkey.clone(), + original_id, + progress_token, + ) + .await; } // Handle initialized notification @@ -843,7 +852,7 @@ impl NostrServerTransport { } async fn cleanup_sessions( - sessions: &RwLock>, + sessions: &SessionStore, event_routes: &ServerEventRouteStore, timeout: Duration, ) -> usize { @@ -903,7 +912,7 @@ mod tests { #[tokio::test] async fn test_cleanup_sessions_removes_expired() { - let sessions = Arc::new(RwLock::new(HashMap::new())); + let sessions = SessionStore::new(); let event_routes = ServerEventRouteStore::new(); // Insert a session with an old activity time @@ -916,7 +925,12 @@ mod tests { .await .insert("pubkey1".to_string(), session); event_routes - .register("evt1".to_string(), "pubkey1".to_string()) + .register( + "evt1".to_string(), + "pubkey1".to_string(), + serde_json::json!(1), + None, + ) .await; // With a long timeout, nothing should be cleaned @@ -927,7 +941,7 @@ mod tests { ) .await; assert_eq!(cleaned, 0); - assert_eq!(sessions.read().await.len(), 1); + assert_eq!(sessions.session_count().await, 1); // With zero timeout, it should be cleaned thread::sleep(Duration::from_millis(5)); @@ -938,17 +952,16 @@ mod tests { ) .await; assert_eq!(cleaned, 1); - assert!(sessions.read().await.is_empty()); + assert_eq!(sessions.session_count().await, 0); assert!(event_routes.pop("evt1").await.is_none()); } #[tokio::test] async fn test_cleanup_preserves_active_sessions() { - let sessions = Arc::new(RwLock::new(HashMap::new())); + let sessions = SessionStore::new(); let event_routes = ServerEventRouteStore::new(); - let session = ClientSession::new(false); - sessions.write().await.insert("active".to_string(), session); + sessions.get_or_create_session("active", false).await; let cleaned = NostrServerTransport::cleanup_sessions( &sessions, @@ -957,7 +970,7 @@ mod tests { ) .await; assert_eq!(cleaned, 0); - assert_eq!(sessions.read().await.len(), 1); + assert_eq!(sessions.session_count().await, 1); } // ── Request ID correlation ────────────────────────────────── diff --git a/src/transport/server/session_store.rs b/src/transport/server/session_store.rs new file mode 100644 index 0000000..3361a9b --- /dev/null +++ b/src/transport/server/session_store.rs @@ -0,0 +1,204 @@ +//! Server-side session store for managing client sessions. + +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::RwLock; + +use crate::core::types::ClientSession; + +/// Manages client sessions keyed by public key (hex). +#[derive(Clone)] +pub struct SessionStore { + sessions: Arc>>, +} + +impl Default for SessionStore { + fn default() -> Self { + Self::new() + } +} + +impl SessionStore { + pub fn new() -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Get an existing session or create a new one. Returns `true` if a new session was created. + pub async fn get_or_create_session(&self, client_pubkey: &str, is_encrypted: bool) -> bool { + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(client_pubkey) { + session.is_encrypted = is_encrypted; + false + } else { + sessions.insert(client_pubkey.to_string(), ClientSession::new(is_encrypted)); + true + } + } + + /// Get a read-only snapshot of session fields. + /// Returns `None` if the session does not exist. + pub async fn get_session(&self, client_pubkey: &str) -> Option { + let sessions = self.sessions.read().await; + sessions.get(client_pubkey).map(|s| SessionSnapshot { + is_initialized: s.is_initialized, + is_encrypted: s.is_encrypted, + }) + } + + /// Mark a session as initialized. Returns `true` if the session existed. + pub async fn mark_initialized(&self, client_pubkey: &str) -> bool { + let mut sessions = self.sessions.write().await; + if let Some(session) = sessions.get_mut(client_pubkey) { + session.is_initialized = true; + true + } else { + false + } + } + + /// Remove a session. Returns `true` if it existed. + pub async fn remove_session(&self, client_pubkey: &str) -> bool { + self.sessions.write().await.remove(client_pubkey).is_some() + } + + /// Remove all sessions. + pub async fn clear(&self) { + self.sessions.write().await.clear(); + } + + /// Number of active sessions. + pub async fn session_count(&self) -> usize { + self.sessions.read().await.len() + } + + /// Return a snapshot of all sessions as `(client_pubkey, snapshot)` pairs. + pub async fn get_all_sessions(&self) -> Vec<(String, SessionSnapshot)> { + let sessions = self.sessions.read().await; + sessions + .iter() + .map(|(k, s)| { + ( + k.clone(), + SessionSnapshot { + is_initialized: s.is_initialized, + is_encrypted: s.is_encrypted, + }, + ) + }) + .collect() + } + + /// Acquire write access to the underlying map (transport internals only). + pub(crate) async fn write( + &self, + ) -> tokio::sync::RwLockWriteGuard<'_, HashMap> { + self.sessions.write().await + } + + /// Acquire read access to the underlying map (transport internals only). + pub(crate) async fn read( + &self, + ) -> tokio::sync::RwLockReadGuard<'_, HashMap> { + self.sessions.read().await + } +} + +/// A lightweight snapshot of session state (avoids exposing the full `ClientSession` +/// through the async API boundary). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionSnapshot { + pub is_initialized: bool, + pub is_encrypted: bool, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn create_and_retrieve_session() { + let store = SessionStore::new(); + + let created = store.get_or_create_session("client-1", true).await; + assert!(created); + + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_encrypted); + assert!(!snap.is_initialized); + } + + #[tokio::test] + async fn get_or_create_returns_existing() { + let store = SessionStore::new(); + + let created = store.get_or_create_session("client-1", false).await; + assert!(created); + + let created2 = store.get_or_create_session("client-1", true).await; + assert!(!created2); + + // is_encrypted should have been updated. + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_encrypted); + } + + #[tokio::test] + async fn mark_initialized() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + assert!(store.mark_initialized("client-1").await); + let snap = store.get_session("client-1").await.unwrap(); + assert!(snap.is_initialized); + } + + #[tokio::test] + async fn mark_initialized_unknown_returns_false() { + let store = SessionStore::new(); + assert!(!store.mark_initialized("unknown").await); + } + + #[tokio::test] + async fn remove_session() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + assert!(store.remove_session("client-1").await); + assert!(store.get_session("client-1").await.is_none()); + } + + #[tokio::test] + async fn remove_unknown_returns_false() { + let store = SessionStore::new(); + assert!(!store.remove_session("unknown").await); + } + + #[tokio::test] + async fn clear_all_sessions() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + store.get_or_create_session("client-2", true).await; + + store.clear().await; + + assert_eq!(store.session_count().await, 0); + assert!(store.get_session("client-1").await.is_none()); + assert!(store.get_session("client-2").await.is_none()); + } + + #[tokio::test] + async fn get_all_sessions() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + store.get_or_create_session("client-2", true).await; + + let all = store.get_all_sessions().await; + assert_eq!(all.len(), 2); + + let keys: Vec<&str> = all.iter().map(|(k, _)| k.as_str()).collect(); + assert!(keys.contains(&"client-1")); + assert!(keys.contains(&"client-2")); + } +} diff --git a/tests/conformance_stores.rs b/tests/conformance_stores.rs new file mode 100644 index 0000000..0aa659f --- /dev/null +++ b/tests/conformance_stores.rs @@ -0,0 +1,725 @@ +//! Conformance tests for store abstractions. +//! +//! Ported from the TS SDK: +//! - `src/transport/nostr-client/correlation-store.test.ts` +//! - `src/transport/nostr-server/session-store.test.ts` +//! - `src/transport/nostr-server/correlation-store.test.ts` +//! +//! LRU eviction tests are deferred — only non-eviction tests are ported here. + +use contextvm_sdk::{ClientCorrelationStore, ServerEventRouteStore, SessionStore}; +use serde_json::json; + +// ════════════════════════════════════════════════════════════════════ +// Client Correlation Store +// ════════════════════════════════════════════════════════════════════ + +mod client_correlation_store { + use super::*; + + // ── registerRequest ─────────────────────────────────────────── + + #[tokio::test] + async fn stores_request_with_event_id() { + let store = ClientCorrelationStore::new(); + store.register("event123".into(), json!("req1")).await; + assert!(store.contains("event123").await); + } + + #[tokio::test] + async fn stores_and_resolves_original_request_id() { + let store = ClientCorrelationStore::new(); + store.register("event456".into(), json!("req2")).await; + + // Retrieve the stored original ID. + let original = store.get_original_id("event456").await.unwrap(); + assert_eq!(original, json!("req2")); + + // After removal the entry is fully gone. + assert!(store.remove("event456").await); + assert!(store.get_original_id("event456").await.is_none()); + } + + // ── resolveResponse (get_original_id + remove) ──────────────── + + #[tokio::test] + async fn restores_original_request_id() { + let store = ClientCorrelationStore::new(); + store.register("event789".into(), json!(42)).await; + let original = store.get_original_id("event789").await.unwrap(); + assert_eq!(original, json!(42)); + } + + #[tokio::test] + async fn returns_none_for_unknown_event_id() { + let store = ClientCorrelationStore::new(); + assert!(store.get_original_id("unknown").await.is_none()); + } + + #[tokio::test] + async fn get_and_remove_roundtrip() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!("req1")).await; + + // Lookup succeeds before removal. + let original = store.get_original_id("event1").await.unwrap(); + assert_eq!(original, json!("req1")); + + // Remove returns true and cleans up completely. + assert!(store.remove("event1").await); + assert!(!store.contains("event1").await); + assert!(store.get_original_id("event1").await.is_none()); + } + + // ── removePendingRequest ────────────────────────────────────── + + #[tokio::test] + async fn removes_existing_request() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!(null)).await; + assert!(store.remove("event1").await); + assert!(!store.contains("event1").await); + } + + #[tokio::test] + async fn returns_false_for_unknown_request() { + let store = ClientCorrelationStore::new(); + assert!(!store.remove("unknown").await); + } + + // ── clear ───────────────────────────────────────────────────── + + #[tokio::test] + async fn removes_all_pending_requests() { + let store = ClientCorrelationStore::new(); + store.register("event1".into(), json!(null)).await; + store.register("event2".into(), json!(null)).await; + store.clear().await; + assert_eq!(store.count().await, 0); + } +} + +// ════════════════════════════════════════════════════════════════════ +// Server Session Store +// ════════════════════════════════════════════════════════════════════ + +mod server_session_store { + use super::*; + + #[tokio::test] + async fn create_and_retrieve_sessions() { + let store = SessionStore::new(); + + let created = store.get_or_create_session("client-1", true).await; + assert!(created); + + let session = store.get_session("client-1").await.unwrap(); + assert!(session.is_encrypted); + assert!(!session.is_initialized); + + // Retrieving same key should return it + assert!(store.get_session("client-1").await.is_some()); + } + + #[tokio::test] + async fn mark_sessions_as_initialized() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + let result = store.mark_initialized("client-1").await; + assert!(result); + + let session = store.get_session("client-1").await.unwrap(); + assert!(session.is_initialized); + } + + #[tokio::test] + async fn remove_sessions() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + let result = store.remove_session("client-1").await; + assert!(result); + assert!(store.get_session("client-1").await.is_none()); + } + + #[tokio::test] + async fn clear_all_sessions() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + store.get_or_create_session("client-2", true).await; + + store.clear().await; + + assert_eq!(store.session_count().await, 0); + assert!(store.get_session("client-1").await.is_none()); + assert!(store.get_session("client-2").await.is_none()); + } + + #[tokio::test] + async fn iterate_over_all_sessions() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + store.get_or_create_session("client-2", true).await; + + let sessions = store.get_all_sessions().await; + assert_eq!(sessions.len(), 2); + + let keys: Vec<&str> = sessions.iter().map(|(k, _)| k.as_str()).collect(); + assert!(keys.contains(&"client-1")); + assert!(keys.contains(&"client-2")); + } +} + +// ════════════════════════════════════════════════════════════════════ +// Server Correlation Store (ServerEventRouteStore) +// ════════════════════════════════════════════════════════════════════ + +mod server_correlation_store { + use super::*; + + // ── registerEventRoute ──────────────────────────────────────── + + #[tokio::test] + async fn registers_route_with_all_fields() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert_eq!(route.client_pubkey, "client1"); + assert_eq!(route.original_request_id, json!("req1")); + assert_eq!(route.progress_token.as_deref(), Some("token1")); + } + + #[tokio::test] + async fn registers_route_without_progress_token() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert!(route.progress_token.is_none()); + } + + #[tokio::test] + async fn registers_route_with_numeric_request_id() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!(42), None) + .await; + + let route = store.get_route("event1").await.unwrap(); + assert_eq!(route.original_request_id, json!(42)); + } + + #[tokio::test] + async fn updates_client_index_when_registering() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + + assert!(store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn registers_progress_token_mapping() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + assert!(store.has_progress_token("token1").await); + } + + // ── getEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn returns_none_for_unknown_event_id() { + let store = ServerEventRouteStore::new(); + assert!(store.get_route("unknown").await.is_none()); + } + + // ── popEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn returns_and_removes_route_atomically() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + let route = store.pop("event1").await.unwrap(); + assert_eq!(route.client_pubkey, "client1"); + assert_eq!(route.original_request_id, json!("req1")); + assert_eq!(route.progress_token.as_deref(), Some("token1")); + + // Route + token mapping should be gone. + assert!(!store.has_event_route("event1").await); + assert!(!store.has_progress_token("token1").await); + + // Second pop is a no-op. + assert!(store.pop("event1").await.is_none()); + } + + // ── getEventIdByProgressToken ───────────────────────────────── + + #[tokio::test] + async fn returns_none_for_unknown_token() { + let store = ServerEventRouteStore::new(); + assert!(store + .get_event_id_by_progress_token("unknown") + .await + .is_none()); + } + + #[tokio::test] + async fn returns_correct_event_id_for_token() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client2".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + assert_eq!( + store + .get_event_id_by_progress_token("token2") + .await + .as_deref(), + Some("event2") + ); + } + + // ── removeRoutesForClient ───────────────────────────────────── + + #[tokio::test] + async fn removes_all_routes_for_client() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store + .register("event3".into(), "client2".into(), json!("req3"), None) + .await; + + let removed = store.remove_for_client("client1").await; + assert_eq!(removed, 2); + + assert!(!store.has_event_route("event1").await); + assert!(!store.has_event_route("event2").await); + assert!(store.has_event_route("event3").await); + } + + #[tokio::test] + async fn returns_zero_for_unknown_client() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.remove_for_client("unknown").await, 0); + } + + #[tokio::test] + async fn cleans_up_progress_tokens_for_removed_routes() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + store.remove_for_client("client1").await; + + assert!(!store.has_progress_token("token1").await); + assert!(!store.has_progress_token("token2").await); + } + + #[tokio::test] + async fn removes_client_from_index_after_cleanup() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + store.remove_for_client("client1").await; + + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── hasEventRoute ───────────────────────────────────────────── + + #[tokio::test] + async fn has_event_route_true_for_existing() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + assert!(store.has_event_route("event1").await); + } + + #[tokio::test] + async fn has_event_route_false_for_unknown() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_event_route("unknown").await); + } + + // ── hasProgressToken ────────────────────────────────────────── + + #[tokio::test] + async fn has_progress_token_true_for_existing() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + assert!(store.has_progress_token("token1").await); + } + + #[tokio::test] + async fn has_progress_token_false_for_unknown() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_progress_token("unknown").await); + } + + // ── hasActiveRoutesForClient ────────────────────────────────── + + #[tokio::test] + async fn has_active_routes_true_when_routes_exist() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + assert!(store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn has_active_routes_false_when_no_routes() { + let store = ServerEventRouteStore::new(); + assert!(!store.has_active_routes_for_client("client1").await); + } + + #[tokio::test] + async fn has_active_routes_false_after_all_popped() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store.pop("event1").await; + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── eventRouteCount ─────────────────────────────────────────── + + #[tokio::test] + async fn event_route_count_zero_for_empty() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.event_route_count().await, 0); + } + + #[tokio::test] + async fn event_route_count_after_registrations() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + assert_eq!(store.event_route_count().await, 2); + } + + #[tokio::test] + async fn event_route_count_after_removals() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store.pop("event1").await; + assert_eq!(store.event_route_count().await, 1); + } + + // ── progressTokenCount ──────────────────────────────────────── + + #[tokio::test] + async fn progress_token_count_zero_for_empty() { + let store = ServerEventRouteStore::new(); + assert_eq!(store.progress_token_count().await, 0); + } + + #[tokio::test] + async fn progress_token_count_after_registrations() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + store + .register("event3".into(), "client1".into(), json!("req3"), None) + .await; + assert_eq!(store.progress_token_count().await, 2); + } + + #[tokio::test] + async fn progress_token_count_after_removals() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + store.pop("event1").await; + assert_eq!(store.progress_token_count().await, 1); + } + + // ── clear ───────────────────────────────────────────────────── + + #[tokio::test] + async fn clear_removes_all_routes() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client2".into(), json!("req2"), None) + .await; + + store.clear().await; + + assert_eq!(store.event_route_count().await, 0); + assert!(!store.has_event_route("event1").await); + } + + #[tokio::test] + async fn clear_removes_all_progress_tokens() { + let store = ServerEventRouteStore::new(); + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + + store.clear().await; + + assert_eq!(store.progress_token_count().await, 0); + assert!(!store.has_progress_token("token1").await); + } + + #[tokio::test] + async fn clear_cleans_up_client_index() { + let store = ServerEventRouteStore::new(); + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + + store.clear().await; + + assert!(!store.has_active_routes_for_client("client1").await); + } + + // ── complex scenarios ───────────────────────────────────────── + + #[tokio::test] + async fn handles_multiple_clients_with_multiple_routes() { + let store = ServerEventRouteStore::new(); + + // Client 1: 2 routes + store + .register( + "c1e1".into(), + "client1".into(), + json!("r1"), + Some("t1".into()), + ) + .await; + store + .register( + "c1e2".into(), + "client1".into(), + json!("r2"), + Some("t2".into()), + ) + .await; + + // Client 2: 1 route + store + .register( + "c2e1".into(), + "client2".into(), + json!("r3"), + Some("t3".into()), + ) + .await; + + assert_eq!(store.event_route_count().await, 3); + assert_eq!(store.progress_token_count().await, 3); + assert!(store.has_active_routes_for_client("client1").await); + assert!(store.has_active_routes_for_client("client2").await); + + // Remove one of client1's routes + store.pop("c1e1").await; + + assert!(store.has_active_routes_for_client("client1").await); + assert!(!store.has_progress_token("t1").await); + assert!(store.has_progress_token("t2").await); + } + + #[tokio::test] + async fn handles_route_replacement_with_same_progress_token() { + let store = ServerEventRouteStore::new(); + + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event1") + ); + + // Register new route with same token (overwrites mapping) + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token1".into()), + ) + .await; + assert_eq!( + store + .get_event_id_by_progress_token("token1") + .await + .as_deref(), + Some("event2") + ); + } + + #[tokio::test] + async fn maintains_consistency_through_mixed_operations() { + let store = ServerEventRouteStore::new(); + + // Add routes + store + .register("e1".into(), "c1".into(), json!("r1"), Some("t1".into())) + .await; + store + .register("e2".into(), "c1".into(), json!("r2"), Some("t2".into())) + .await; + store + .register("e3".into(), "c2".into(), json!("r3"), Some("t3".into())) + .await; + + // Remove one + store.pop("e2").await; + + // Verify consistency + assert!(store.has_event_route("e1").await); + assert!(!store.has_event_route("e2").await); + assert!(store.has_event_route("e3").await); + + assert!(store.has_progress_token("t1").await); + assert!(!store.has_progress_token("t2").await); + assert!(store.has_progress_token("t3").await); + + assert!(store.has_active_routes_for_client("c1").await); + assert!(store.has_active_routes_for_client("c2").await); + } +} From c2e9c48eabf775ec034e12cd4d1a1f8583a8aba8 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Thu, 23 Apr 2026 14:41:52 +0530 Subject: [PATCH 44/69] feat(cep-19): Added ephemeral package handling logic in client transport --- src/proxy/mod.rs | 1 + src/transport/base.rs | 12 +- src/transport/client/mod.rs | 171 ++++++++++++++++++++++++++-- src/transport/server/mod.rs | 2 + tests/conformance_stateless_mode.rs | 3 +- 5 files changed, 179 insertions(+), 10 deletions(-) diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index df1d386..1322fa0 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -110,6 +110,7 @@ mod tests { relay_urls: vec!["wss://relay.example.com".to_string()], server_pubkey: server_pubkey.clone(), encryption_mode: EncryptionMode::Required, + gift_wrap_mode: GiftWrapMode::Optional, is_stateless: true, timeout: Duration::from_secs(60), log_file_path: None, diff --git a/src/transport/base.rs b/src/transport/base.rs index 4f488a5..ccd7e03 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -112,6 +112,7 @@ impl BaseTransport { kind: u16, tags: Vec, is_encrypted: Option, + gift_wrap_kind: Option, ) -> Result { let should_encrypt = self.should_encrypt(kind, is_encrypted); @@ -128,13 +129,20 @@ impl BaseTransport { .signer() .await .map_err(|e| Error::Encryption(e.to_string()))?; - let gift_wrap_event = - encryption::gift_wrap_single_layer(&signer, recipient, &event_json).await?; + let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND); + let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind( + &signer, + recipient, + &event_json, + selected_gift_wrap_kind, + ) + .await?; self.relay_pool.publish_event(&gift_wrap_event).await?; tracing::debug!( target: LOG_TARGET, signed_event_id = %signed_event_id, envelope_id = %gift_wrap_event.id, + gift_wrap_kind = selected_gift_wrap_kind, "Sent encrypted MCP message" ); } else { diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 955e8fc..32e7427 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -8,6 +8,7 @@ pub mod correlation_store; pub use correlation_store::ClientCorrelationStore; use std::num::NonZeroUsize; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -35,6 +36,8 @@ pub struct NostrClientTransportConfig { pub server_pubkey: String, /// Encryption mode. pub encryption_mode: EncryptionMode, + /// Gift-wrap policy for encrypted messages. + pub gift_wrap_mode: GiftWrapMode, /// Stateless mode: emulate initialize response locally. pub is_stateless: bool, /// Response timeout (default: 30s). @@ -49,6 +52,7 @@ impl Default for NostrClientTransportConfig { relay_urls: vec!["wss://relay.damus.io".to_string()], server_pubkey: String::new(), encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, is_stateless: false, timeout: Duration::from_secs(30), log_file_path: None, @@ -63,6 +67,8 @@ pub struct NostrClientTransport { server_pubkey: PublicKey, /// Pending request event IDs awaiting responses. pending_requests: ClientCorrelationStore, + /// Learned support for server-side ephemeral gift wraps. + server_supports_ephemeral: Arc, /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). /// Duplicate outer ids are skipped before decrypt; ids are inserted only after success /// so failed decrypt/verify can be retried on redelivery. @@ -120,6 +126,7 @@ impl NostrClientTransport { config, server_pubkey, pending_requests: ClientCorrelationStore::new(), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), @@ -164,6 +171,7 @@ impl NostrClientTransport { config, server_pubkey, pending_requests: ClientCorrelationStore::new(), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, message_tx: tx, message_rx: Some(rx), @@ -217,6 +225,8 @@ impl NostrClientTransport { let server_pubkey = self.server_pubkey; let tx = self.message_tx.clone(); let encryption_mode = self.config.encryption_mode; + let gift_wrap_mode = self.config.gift_wrap_mode; + let server_supports_ephemeral = self.server_supports_ephemeral.clone(); let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); tokio::spawn(async move { @@ -226,6 +236,8 @@ impl NostrClientTransport { server_pubkey, tx, encryption_mode, + gift_wrap_mode, + server_supports_ephemeral, seen_gift_wrap_ids, ) .await; @@ -270,6 +282,7 @@ impl NostrClientTransport { CTXVM_MESSAGES_KIND, tags, None, + Some(self.choose_outbound_gift_wrap_kind()), ) .await .map_err(|error| { @@ -323,12 +336,15 @@ impl NostrClientTransport { let _ = self.message_tx.send(response); } + #[allow(clippy::too_many_arguments)] async fn event_loop( relay_pool: Arc, pending: ClientCorrelationStore, server_pubkey: PublicKey, tx: tokio::sync::mpsc::UnboundedSender, encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + server_supports_ephemeral: Arc, seen_gift_wrap_ids: Arc>>, ) { let mut notifications = relay_pool.notifications(); @@ -336,25 +352,42 @@ impl NostrClientTransport { while let Ok(notification) = notifications.recv().await { if let RelayPoolNotification::Event { event, .. } = notification { let is_gift_wrap = is_gift_wrap_kind(&event.kind); + let outer_kind = event.kind.as_u16(); - // Enforce mode before decrypt/parse. + // Enforce encryption mode before decrypt/parse. if violates_encryption_policy(&event.kind, &encryption_mode) { if is_gift_wrap { tracing::warn!( + target: LOG_TARGET, event_id = %event.id.to_hex(), - "Received encrypted response but encryption is disabled" + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping encrypted response because client encryption is disabled" ); } else { tracing::warn!( + target: LOG_TARGET, event_id = %event.id.to_hex(), - "Received unencrypted response but encryption is required" + "Skipping plaintext response because client encryption is required" ); } continue; } + // Enforce CEP-19 gift-wrap-mode policy. + if is_gift_wrap && !gift_wrap_mode.allows_kind(outer_kind) { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping gift wrap due to CEP-19 policy" + ); + continue; + } + // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag) = if is_gift_wrap { + let (actual_event_content, actual_pubkey, e_tag, verified_tags) = if is_gift_wrap { { let guard = match seen_gift_wrap_ids.lock() { Ok(g) => g, @@ -399,7 +432,7 @@ impl NostrClientTransport { guard.put(event.id, ()); } let e_tag = serializers::get_tag_value(&inner.tags, "e"); - (inner.content, inner.pubkey, e_tag) + (inner.content, inner.pubkey, e_tag, inner.tags) } Err(error) => { tracing::error!( @@ -422,7 +455,12 @@ impl NostrClientTransport { } } else { let e_tag = serializers::get_tag_value(&event.tags, "e"); - (event.content.clone(), event.pubkey, e_tag) + ( + event.content.clone(), + event.pubkey, + e_tag, + event.tags.clone(), + ) }; // Verify it's from our server @@ -436,6 +474,16 @@ impl NostrClientTransport { continue; } + // CEP-19: learn ephemeral support from server + if Self::should_learn_ephemeral_support( + actual_pubkey, + server_pubkey, + if is_gift_wrap { Some(outer_kind) } else { None }, + &verified_tags, + ) { + server_supports_ephemeral.store(true, Ordering::Relaxed); + } + // Correlate response if let Some(ref correlated_id) = e_tag { let is_pending = pending.contains(correlated_id.as_str()).await; @@ -460,6 +508,45 @@ impl NostrClientTransport { } } } + + fn choose_outbound_gift_wrap_kind(&self) -> u16 { + match self.config.gift_wrap_mode { + GiftWrapMode::Persistent => GIFT_WRAP_KIND, + GiftWrapMode::Ephemeral => EPHEMERAL_GIFT_WRAP_KIND, + GiftWrapMode::Optional => { + if self.server_supports_ephemeral.load(Ordering::Relaxed) { + EPHEMERAL_GIFT_WRAP_KIND + } else { + GIFT_WRAP_KIND + } + } + } + } + + fn has_support_ephemeral_tag(tags: &Tags) -> bool { + tags.iter().any(|tag| { + tag.kind() + == TagKind::Custom( + crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into(), + ) + }) + } + + fn should_learn_ephemeral_support( + actual_pubkey: PublicKey, + server_pubkey: PublicKey, + event_kind: Option, + tags: &Tags, + ) -> bool { + actual_pubkey == server_pubkey + && (event_kind == Some(EPHEMERAL_GIFT_WRAP_KIND) + || Self::has_support_ephemeral_tag(tags)) + } + + /// Returns whether the client has learned ephemeral gift-wrap support from the server. + pub fn server_supports_ephemeral_encryption(&self) -> bool { + self.server_supports_ephemeral.load(Ordering::Relaxed) + } } #[inline] @@ -486,6 +573,7 @@ mod tests { assert_eq!(config.relay_urls, vec!["wss://relay.damus.io".to_string()]); assert!(config.server_pubkey.is_empty()); assert_eq!(config.encryption_mode, EncryptionMode::Optional); + assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); assert!(!config.is_stateless); assert_eq!(config.timeout, Duration::from_secs(30)); assert!(config.log_file_path.is_none()); @@ -500,9 +588,78 @@ mod tests { assert!(config.is_stateless); } + #[test] + fn test_has_support_ephemeral_tag_detects_capability() { + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + assert!(NostrClientTransport::has_support_ephemeral_tag(&tags)); + } + + #[test] + fn test_has_support_ephemeral_tag_absent() { + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )]); + assert!(!NostrClientTransport::has_support_ephemeral_tag(&tags)); + } + + #[test] + fn test_should_learn_ephemeral_support_requires_matching_server_pubkey() { + let server_keys = Keys::generate(); + let other_keys = Keys::generate(); + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + + assert!(!NostrClientTransport::should_learn_ephemeral_support( + other_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &tags, + )); + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &tags, + )); + } + + #[test] + fn test_should_learn_from_ephemeral_kind_even_without_tag() { + let server_keys = Keys::generate(); + let empty_tags = Tags::from_list(vec![]); + + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(EPHEMERAL_GIFT_WRAP_KIND), + &empty_tags, + )); + } + + #[test] + fn test_should_learn_from_tag_without_ephemeral_kind() { + let server_keys = Keys::generate(); + let tags = Tags::from_list(vec![Tag::custom( + TagKind::Custom(crate::core::constants::tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )]); + + assert!(NostrClientTransport::should_learn_ephemeral_support( + server_keys.public_key(), + server_keys.public_key(), + Some(GIFT_WRAP_KIND), // persistent kind, but tag present + &tags, + )); + } + #[test] fn test_stateless_emulated_initialize_response_shape() { - // Verify the emulated response has the expected structure let request_id = serde_json::json!(1); let response = JsonRpcMessage::Response(JsonRpcResponse { jsonrpc: "2.0".to_string(), diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 7781557..4516bec 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -346,6 +346,7 @@ impl NostrServerTransport { CTXVM_MESSAGES_KIND, tags, Some(is_encrypted), + None, ) .await .map_err(|error| { @@ -412,6 +413,7 @@ impl NostrServerTransport { CTXVM_MESSAGES_KIND, tags, Some(is_encrypted), + None, ) .await?; diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 54b633b..6eee1b6 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -3,7 +3,7 @@ use std::time::Duration; use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; -use contextvm_sdk::core::types::{EncryptionMode, JsonRpcMessage, JsonRpcRequest}; +use contextvm_sdk::core::types::{EncryptionMode, GiftWrapMode, JsonRpcMessage, JsonRpcRequest}; use contextvm_sdk::signer; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use tokio::time::timeout; @@ -19,6 +19,7 @@ async fn make_stateless_transport() -> ( relay_urls: Vec::new(), server_pubkey: server_keys.public_key().to_hex(), encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, is_stateless: true, timeout: Duration::from_secs(1), log_file_path: None, From e5fe6d66ee90298dea3df6a32805559a1448f973 Mon Sep 17 00:00:00 2001 From: Harsh Date: Fri, 24 Apr 2026 16:30:06 +0530 Subject: [PATCH 45/69] test: add LRU eviction and registerRequest initialize flag conformance tests --- src/transport/client/correlation_store.rs | 78 +++++++++++---- src/transport/client/mod.rs | 3 +- src/transport/server/correlation_store.rs | 61 ++++++++---- tests/conformance_stores.rs | 112 ++++++++++++++++++++-- 4 files changed, 210 insertions(+), 44 deletions(-) diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs index 140f664..924e283 100644 --- a/src/transport/client/correlation_store.rs +++ b/src/transport/client/correlation_store.rs @@ -1,14 +1,27 @@ //! Client-side correlation store for tracking pending request event IDs. -use std::collections::HashMap; +use std::num::NonZeroUsize; use std::sync::Arc; +use lru::LruCache; use tokio::sync::RwLock; +/// A pending request tracked by the correlation store. +#[derive(Debug, Clone)] +pub struct PendingRequest { + /// The original JSON-RPC request ID before event-ID replacement. + pub original_id: serde_json::Value, + /// Whether this request is an `initialize` handshake. + pub is_initialize: bool, +} + /// Tracks pending request event IDs and their original request IDs on the client side. +/// +/// An optional capacity limit enables LRU eviction of the oldest entry when the +/// store is full. #[derive(Clone)] pub struct ClientCorrelationStore { - pending_requests: Arc>>, + pending_requests: Arc>>, } impl Default for ClientCorrelationStore { @@ -20,34 +33,61 @@ impl Default for ClientCorrelationStore { impl ClientCorrelationStore { pub fn new() -> Self { Self { - pending_requests: Arc::new(RwLock::new(HashMap::new())), + pending_requests: Arc::new(RwLock::new(LruCache::unbounded())), + } + } + + /// Create a store with an upper bound on pending requests. + /// When the limit is reached the oldest entry is evicted. + pub fn with_max_pending(max_pending: usize) -> Self { + Self { + pending_requests: Arc::new(RwLock::new(LruCache::new( + NonZeroUsize::new(max_pending).expect("max_pending must be non-zero"), + ))), } } /// Register a pending request with its original JSON-RPC request ID. - pub async fn register(&self, event_id: String, original_id: serde_json::Value) { + pub async fn register( + &self, + event_id: String, + original_id: serde_json::Value, + is_initialize: bool, + ) { + self.pending_requests.write().await.push( + event_id, + PendingRequest { + original_id, + is_initialize, + }, + ); + } + + /// Check whether a given event ID corresponds to an `initialize` request. + pub async fn is_initialize_request(&self, event_id: &str) -> bool { self.pending_requests - .write() + .read() .await - .insert(event_id, original_id); + .peek(event_id) + .is_some_and(|r| r.is_initialize) } pub async fn contains(&self, event_id: &str) -> bool { - self.pending_requests.read().await.contains_key(event_id) + self.pending_requests.read().await.contains(event_id) } /// Remove a pending request. Returns `true` if the key existed. pub async fn remove(&self, event_id: &str) -> bool { - self.pending_requests - .write() - .await - .remove(event_id) - .is_some() + self.pending_requests.write().await.pop(event_id).is_some() } /// Retrieve the original request ID for a given event ID without removing it. pub async fn get_original_id(&self, event_id: &str) -> Option { - self.pending_requests.read().await.get(event_id).cloned() + self.pending_requests + .read() + .await + .peek(event_id) + .map(|r| r.original_id.clone()) } /// Number of pending requests currently tracked. @@ -74,8 +114,12 @@ mod tests { #[tokio::test] async fn contains_after_clear() { let store = ClientCorrelationStore::new(); - store.register("e1".into(), serde_json::Value::Null).await; - store.register("e2".into(), serde_json::Value::Null).await; + store + .register("e1".into(), serde_json::Value::Null, false) + .await; + store + .register("e2".into(), serde_json::Value::Null, false) + .await; assert!(store.contains("e1").await); store.clear().await; assert!(!store.contains("e1").await); @@ -85,7 +129,9 @@ mod tests { #[tokio::test] async fn register_and_remove_roundtrip() { let store = ClientCorrelationStore::new(); - store.register("e1".into(), serde_json::Value::Null).await; + store + .register("e1".into(), serde_json::Value::Null, false) + .await; assert!(store.contains("e1").await); assert!(store.remove("e1").await); assert!(!store.contains("e1").await); diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index d7f2c2b..f615967 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -284,8 +284,9 @@ impl NostrClientTransport { })?; if let JsonRpcMessage::Request(ref req) = message { + let is_initialize = req.method == INITIALIZE_METHOD; self.pending_requests - .register(event_id.to_hex(), req.id.clone()) + .register(event_id.to_hex(), req.id.clone(), is_initialize) .await; } diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index 1879d80..09fd16e 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -1,8 +1,10 @@ //! Server-side event route store for mapping event IDs to client routes. use std::collections::{HashMap, HashSet}; +use std::num::NonZeroUsize; use std::sync::Arc; +use lru::LruCache; use tokio::sync::RwLock; /// A route entry for an in-flight request. @@ -18,8 +20,8 @@ pub struct RouteEntry { /// Internal state behind the lock. struct Inner { - /// Primary index: event_id → route entry. - routes: HashMap, + /// Primary index: event_id → route entry (LRU-ordered). + routes: LruCache, /// Secondary index: progress_token → event_id. progress_token_to_event: HashMap, /// Secondary index: client_pubkey → set of event_ids. @@ -27,36 +29,43 @@ struct Inner { } impl Inner { - fn new() -> Self { + fn new(max_routes: Option) -> Self { + let routes = match max_routes { + Some(n) => LruCache::new(NonZeroUsize::new(n).expect("max_routes must be non-zero")), + None => LruCache::unbounded(), + }; Self { - routes: HashMap::new(), + routes, progress_token_to_event: HashMap::new(), client_event_ids: HashMap::new(), } } - /// Remove a single route and clean up all secondary indexes. - fn remove_route(&mut self, event_id: &str) -> Option { - let route = self.routes.remove(event_id)?; - - // Clean up progress token index. + /// Clean up secondary indexes for a removed route. + fn cleanup_indexes(&mut self, event_id: &str, route: &RouteEntry) { if let Some(ref token) = route.progress_token { self.progress_token_to_event.remove(token); } - - // Clean up client index. if let Some(set) = self.client_event_ids.get_mut(&route.client_pubkey) { set.remove(event_id); if set.is_empty() { self.client_event_ids.remove(&route.client_pubkey); } } + } + /// Remove a single route and clean up all secondary indexes. + fn remove_route(&mut self, event_id: &str) -> Option { + let route = self.routes.pop(event_id)?; + self.cleanup_indexes(event_id, &route); Some(route) } } /// Maps event IDs to full route entries for response routing on the server side. +/// +/// An optional capacity limit enables LRU eviction; when the limit is reached +/// the oldest entry is evicted and its secondary indexes are cleaned up. #[derive(Clone)] pub struct ServerEventRouteStore { inner: Arc>, @@ -71,7 +80,15 @@ impl Default for ServerEventRouteStore { impl ServerEventRouteStore { pub fn new() -> Self { Self { - inner: Arc::new(RwLock::new(Inner::new())), + inner: Arc::new(RwLock::new(Inner::new(None))), + } + } + + /// Create a store with an upper bound on event routes. + /// When the limit is reached the oldest entry is evicted. + pub fn with_max_routes(max_routes: usize) -> Self { + Self { + inner: Arc::new(RwLock::new(Inner::new(Some(max_routes)))), } } @@ -99,14 +116,22 @@ impl ServerEventRouteStore { .insert(token.clone(), event_id.clone()); } - inner.routes.insert( - event_id, + // Insert into LRU; handle possible eviction. + let evicted = inner.routes.push( + event_id.clone(), RouteEntry { client_pubkey, original_request_id, progress_token, }, ); + + if let Some((evicted_key, evicted_route)) = evicted { + if evicted_key != event_id { + // A different entry was evicted due to capacity — clean up its indexes. + inner.cleanup_indexes(&evicted_key, &evicted_route); + } + } } /// Returns the client public key for the given event ID without removing it. @@ -115,13 +140,13 @@ impl ServerEventRouteStore { .read() .await .routes - .get(event_id) + .peek(event_id) .map(|r| r.client_pubkey.clone()) } /// Returns the full route entry for the given event ID without removing it. pub async fn get_route(&self, event_id: &str) -> Option { - self.inner.read().await.routes.get(event_id).cloned() + self.inner.read().await.routes.peek(event_id).cloned() } /// Removes and returns the full route entry for the given event ID. @@ -140,7 +165,7 @@ impl ServerEventRouteStore { let count = event_ids.len(); for event_id in &event_ids { - if let Some(route) = inner.routes.remove(event_id.as_str()) { + if let Some(route) = inner.routes.pop(event_id.as_str()) { if let Some(ref token) = route.progress_token { inner.progress_token_to_event.remove(token); } @@ -151,7 +176,7 @@ impl ServerEventRouteStore { /// Check whether a route exists for the given event ID. pub async fn has_event_route(&self, event_id: &str) -> bool { - self.inner.read().await.routes.contains_key(event_id) + self.inner.read().await.routes.contains(event_id) } /// Check whether the given client has any active routes. diff --git a/tests/conformance_stores.rs b/tests/conformance_stores.rs index 0aa659f..55917ec 100644 --- a/tests/conformance_stores.rs +++ b/tests/conformance_stores.rs @@ -4,8 +4,6 @@ //! - `src/transport/nostr-client/correlation-store.test.ts` //! - `src/transport/nostr-server/session-store.test.ts` //! - `src/transport/nostr-server/correlation-store.test.ts` -//! -//! LRU eviction tests are deferred — only non-eviction tests are ported here. use contextvm_sdk::{ClientCorrelationStore, ServerEventRouteStore, SessionStore}; use serde_json::json; @@ -22,14 +20,18 @@ mod client_correlation_store { #[tokio::test] async fn stores_request_with_event_id() { let store = ClientCorrelationStore::new(); - store.register("event123".into(), json!("req1")).await; + store + .register("event123".into(), json!("req1"), false) + .await; assert!(store.contains("event123").await); } #[tokio::test] async fn stores_and_resolves_original_request_id() { let store = ClientCorrelationStore::new(); - store.register("event456".into(), json!("req2")).await; + store + .register("event456".into(), json!("req2"), false) + .await; // Retrieve the stored original ID. let original = store.get_original_id("event456").await.unwrap(); @@ -40,12 +42,23 @@ mod client_correlation_store { assert!(store.get_original_id("event456").await.is_none()); } + #[tokio::test] + async fn register_request_flags_initialize_requests() { + let store = ClientCorrelationStore::new(); + store.register("e_init".into(), json!("r1"), true).await; + store.register("e_normal".into(), json!("r2"), false).await; + + assert!(store.is_initialize_request("e_init").await); + assert!(!store.is_initialize_request("e_normal").await); + assert!(!store.is_initialize_request("unknown").await); + } + // ── resolveResponse (get_original_id + remove) ──────────────── #[tokio::test] async fn restores_original_request_id() { let store = ClientCorrelationStore::new(); - store.register("event789".into(), json!(42)).await; + store.register("event789".into(), json!(42), false).await; let original = store.get_original_id("event789").await.unwrap(); assert_eq!(original, json!(42)); } @@ -59,7 +72,7 @@ mod client_correlation_store { #[tokio::test] async fn get_and_remove_roundtrip() { let store = ClientCorrelationStore::new(); - store.register("event1".into(), json!("req1")).await; + store.register("event1".into(), json!("req1"), false).await; // Lookup succeeds before removal. let original = store.get_original_id("event1").await.unwrap(); @@ -76,7 +89,7 @@ mod client_correlation_store { #[tokio::test] async fn removes_existing_request() { let store = ClientCorrelationStore::new(); - store.register("event1".into(), json!(null)).await; + store.register("event1".into(), json!(null), false).await; assert!(store.remove("event1").await); assert!(!store.contains("event1").await); } @@ -92,11 +105,30 @@ mod client_correlation_store { #[tokio::test] async fn removes_all_pending_requests() { let store = ClientCorrelationStore::new(); - store.register("event1".into(), json!(null)).await; - store.register("event2".into(), json!(null)).await; + store.register("event1".into(), json!(null), false).await; + store.register("event2".into(), json!(null), false).await; store.clear().await; assert_eq!(store.count().await, 0); } + + // ── LRU eviction (TS SDK client test 9) ─────────────────────── + + #[tokio::test] + async fn evicts_oldest_when_capacity_reached() { + let store = ClientCorrelationStore::with_max_pending(2); + for i in 0..5 { + store + .register(format!("event{i}"), json!(null), false) + .await; + } + assert_eq!(store.count().await, 2); + // Only the two most recent entries survive. + assert!(!store.contains("event0").await); + assert!(!store.contains("event1").await); + assert!(!store.contains("event2").await); + assert!(store.contains("event3").await); + assert!(store.contains("event4").await); + } } // ════════════════════════════════════════════════════════════════════ @@ -722,4 +754,66 @@ mod server_correlation_store { assert!(store.has_active_routes_for_client("c1").await); assert!(store.has_active_routes_for_client("c2").await); } + + // ── LRU eviction (TS SDK server tests 28–30) ───────────────── + + #[tokio::test] + async fn evicts_oldest_route_when_capacity_reached() { + let store = ServerEventRouteStore::with_max_routes(2); + + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client1".into(), json!("req2"), None) + .await; + store + .register("event3".into(), "client1".into(), json!("req3"), None) + .await; + + // event1 should have been evicted. + assert!(!store.has_event_route("event1").await); + assert_eq!(store.event_route_count().await, 2); + } + + #[tokio::test] + async fn cleans_up_progress_tokens_on_eviction() { + let store = ServerEventRouteStore::with_max_routes(1); + + store + .register( + "event1".into(), + "client1".into(), + json!("req1"), + Some("token1".into()), + ) + .await; + store + .register( + "event2".into(), + "client1".into(), + json!("req2"), + Some("token2".into()), + ) + .await; + + assert!(!store.has_progress_token("token1").await); + assert!(store.has_progress_token("token2").await); + } + + #[tokio::test] + async fn cleans_up_client_index_on_eviction() { + let store = ServerEventRouteStore::with_max_routes(1); + + store + .register("event1".into(), "client1".into(), json!("req1"), None) + .await; + store + .register("event2".into(), "client2".into(), json!("req2"), None) + .await; + + // client1's only route was evicted. + assert!(!store.has_active_routes_for_client("client1").await); + assert!(store.has_active_routes_for_client("client2").await); + } } From e5ccd48ede71635617a9b9b6376c2ef9795fd25d Mon Sep 17 00:00:00 2001 From: Harsh Date: Sun, 26 Apr 2026 07:01:00 +0530 Subject: [PATCH 46/69] test: port remaining integration and stateless mode tests, closing all TS SDK parity gaps --- tests/conformance_stateless_mode.rs | 26 +- tests/transport_integration.rs | 1200 ++++++++++++++++++++++++++- 2 files changed, 1222 insertions(+), 4 deletions(-) diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 54b633b..a1a0a3c 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -3,7 +3,9 @@ use std::time::Duration; use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; -use contextvm_sdk::core::types::{EncryptionMode, JsonRpcMessage, JsonRpcRequest}; +use contextvm_sdk::core::types::{ + EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, +}; use contextvm_sdk::signer; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use tokio::time::timeout; @@ -169,3 +171,25 @@ async fn should_handle_statelessly_returns_false_for_other_methods() { "non-initialize request should not create a local emulated response" ); } + +#[tokio::test] +async fn notifications_initialized_swallowed_in_stateless_mode() { + let (transport, mut rx) = make_stateless_transport().await; + + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + + transport + .send(¬ification) + .await + .expect("notifications/initialized should be accepted in stateless mode"); + + let recv_result = timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + recv_result.is_err(), + "notifications/initialized must be swallowed in stateless mode" + ); +} diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index c577b79..fede992 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -9,15 +9,16 @@ use std::sync::Arc; use std::time::Duration; use contextvm_sdk::core::constants::{ - mcp_protocol_version, GIFT_WRAP_KIND, SERVER_ANNOUNCEMENT_KIND, + mcp_protocol_version, CTXVM_MESSAGES_KIND, GIFT_WRAP_KIND, PROMPTS_LIST_KIND, + RESOURCES_LIST_KIND, RESOURCETEMPLATES_LIST_KIND, SERVER_ANNOUNCEMENT_KIND, TOOLS_LIST_KIND, }; use contextvm_sdk::core::types::EncryptionMode; use contextvm_sdk::relay::mock::MockRelayPool; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; use contextvm_sdk::{ - JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, RelayPoolTrait, - ServerInfo, + CapabilityExclusion, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + RelayPoolTrait, ServerInfo, }; use nostr_sdk::prelude::*; @@ -601,3 +602,1196 @@ async fn correlated_notification_has_e_tag() { "notification event must have e tag referencing the original request event id" ); } + +// ── 9. Encryption Required client, Optional server ────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_required_client_optional_server() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("enc-opt-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive encrypted request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "Optional-mode server must accept encrypted messages from Required-mode client" + ); + assert!( + incoming.is_encrypted, + "message from Required-mode client must be marked encrypted" + ); +} + +// ── 10. Encryption Optional both sides, encrypted path ────────────────────── +// Optional client defaults to encrypting (unwrap_or(true)), Optional server +// accepts encrypted messages. Tests the Optional/Optional negotiation path. + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_optional_both_sides_encrypted_path() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("opt-both-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + assert!( + incoming.is_encrypted, + "Optional client defaults to encrypting; Optional server must accept" + ); +} + +// ── 11. Announce includes encryption tags ──────────────────────────────────── + +#[tokio::test] +async fn announce_includes_encryption_tags() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + is_announced_server: true, + server_info: Some(ServerInfo { + name: Some("Encrypted-Server".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)) + .expect("kind 11316 event must be published"); + + // support_encryption is a valueless tag — check tag name directly. + let has_support_encryption = announcement + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + let has_support_encryption_ephemeral = announcement.tags.iter().any(|t| { + t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption_ephemeral") + }); + + assert!( + has_support_encryption, + "announcement must include support_encryption tag" + ); + assert!( + has_support_encryption_ephemeral, + "announcement must include support_encryption_ephemeral tag" + ); +} + +// ── 12. Announce includes server metadata tags ────────────────────────────── + +#[tokio::test] +async fn announce_includes_server_metadata_tags() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + is_announced_server: true, + server_info: Some(ServerInfo { + name: Some("Meta-Server".to_string()), + about: Some("A test server".to_string()), + website: Some("https://example.com".to_string()), + picture: Some("https://example.com/pic.png".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let events = pool.stored_events().await; + let announcement = events + .iter() + .find(|e| e.kind == Kind::Custom(SERVER_ANNOUNCEMENT_KIND)) + .expect("kind 11316 event must be published"); + + let name_tag = contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "name"); + let about_tag = contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "about"); + let website_tag = + contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "website"); + let picture_tag = + contextvm_sdk::core::serializers::get_tag_value(&announcement.tags, "picture"); + + assert_eq!( + name_tag.as_deref(), + Some("Meta-Server"), + "name tag must be present" + ); + assert_eq!( + about_tag.as_deref(), + Some("A test server"), + "about tag must be present" + ); + assert_eq!( + website_tag.as_deref(), + Some("https://example.com"), + "website tag must be present" + ); + assert_eq!( + picture_tag.as_deref(), + Some("https://example.com/pic.png"), + "picture tag must be present" + ); +} + +// ── 13. Publish tools produces correct kind ───────────────────────────────── + +#[tokio::test] +async fn publish_tools_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + is_announced_server: true, + server_info: Some(ServerInfo { + name: Some("Tools-Server".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + + let tools = vec![serde_json::json!({ + "name": "get_weather", + "description": "Get the weather", + "inputSchema": { "type": "object" } + })]; + server.publish_tools(tools).await.expect("publish tools"); + + let events = pool.stored_events().await; + let tools_event = events + .iter() + .find(|e| e.kind == Kind::Custom(TOOLS_LIST_KIND)) + .expect("kind 11317 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&tools_event.content).expect("tools content must be JSON"); + assert!( + content.get("tools").is_some(), + "tools event content must contain 'tools' key" + ); + let tools_arr = content["tools"].as_array().expect("tools must be an array"); + assert_eq!(tools_arr.len(), 1); + assert_eq!(tools_arr[0]["name"], "get_weather"); +} + +// ── 14. Broadcast notification reaches initialized client ───────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn broadcast_notification_reaches_initialized_client() { + let (c1_pool, s_pool) = MockRelayPool::create_pair(); + let server_pk = s_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(s_pool), + ) + .await + .expect("create server transport"); + + let mut srv_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pk.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(c1_pool), + ) + .await + .expect("create client transport"); + let mut c_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client sends initialize request. + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "c1", "version": "0.0.0" } + })), + }); + client + .send(&init_req) + .await + .expect("client send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), srv_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Server responds to initialize. + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test-server", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + // Client receives the init response. + let _ = tokio::time::timeout(Duration::from_millis(500), c_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Client sends notifications/initialized → session becomes initialized. + let init_notif = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }); + client + .send(&init_notif) + .await + .expect("send initialized notification"); + + // Drain srv_rx until we see notifications/initialized (skipping any + // echoed events from the shared mock relay broadcast channel). + loop { + let msg = tokio::time::timeout(Duration::from_millis(500), srv_rx.recv()) + .await + .expect("timeout waiting for notifications/initialized on server") + .expect("server channel closed"); + if msg.message.method() == Some("notifications/initialized") { + break; + } + } + + // Now broadcast — only the initialized client session should receive it. + let broadcast = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "bc-1", "progress": 1, "total": 1 })), + }); + server + .broadcast_notification(&broadcast) + .await + .expect("broadcast notification"); + + let msg = tokio::time::timeout(Duration::from_millis(500), c_rx.recv()) + .await + .expect("timeout waiting for client to receive broadcast") + .expect("client channel closed"); + + assert_eq!(msg.method(), Some("notifications/progress")); +} + +// ── 15. Uncorrelated notification passes through ──────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn uncorrelated_notification_passes_through() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("unc-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "unc-test", "version": "0.0.0" } + })), + }); + client.send(&init_req).await.expect("send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("unc-init"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Uncorrelated notification (no e tag) must pass through to client. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "unc-1", "progress": 50, "total": 100 })), + }); + server + .send_notification(&incoming.client_pubkey, ¬ification, None) + .await + .expect("send uncorrelated notification"); + + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive notification") + .expect("client channel closed"); + + assert!(client_msg.is_notification()); + assert_eq!(client_msg.method(), Some("notifications/progress")); +} + +// ── 16. Correlated notification with unknown e tag is dropped ─────────────── +// NOTE: The Rust SDK drops ANY server event whose e-tag references an unknown +// pending request, including notifications. The TS SDK may forward such events. +// This test documents the Rust SDK's stricter correlation enforcement. + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn correlated_notification_unknown_e_tag_is_dropped() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let init_req = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { "name": "corr-test", "version": "0.0.0" } + })), + }); + client.send(&init_req).await.expect("send initialize"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + let init_resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("corr-init"), + result: serde_json::json!({ + "protocolVersion": mcp_protocol_version(), + "serverInfo": { "name": "test", "version": "0.0.0" }, + "capabilities": {} + }), + }); + server + .send_response(&incoming.event_id, init_resp) + .await + .expect("send init response"); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Notification with e tag referencing unknown event id must be dropped. + let fake_event_id = "a".repeat(64); + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: Some(serde_json::json!({ "progressToken": "fake", "progress": 1, "total": 1 })), + }); + server + .send_notification(&incoming.client_pubkey, ¬ification, Some(&fake_event_id)) + .await + .expect("send notification with unknown e tag"); + + let result = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()).await; + assert!( + result.is_err(), + "notification with unknown e tag must be dropped by client" + ); +} + +// ── 17. Auth: allowed pubkey receives response ────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn auth_allowed_pubkey_receives_response() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let client_pubkey = client_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![client_pubkey.to_hex()], + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("auth-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server should receive it (pubkey is in the allowlist). + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive request") + .expect("server channel closed"); + + assert_eq!(incoming.message.method(), Some("tools/list")); + + // Server sends response back. + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("auth-1"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send response"); + + // Client should receive the response. + let client_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for client to receive response") + .expect("client channel closed"); + + assert!(client_msg.is_response()); + assert_eq!(client_msg.id(), Some(&serde_json::json!("auth-1"))); +} + +// ── 18. Excluded capability bypasses auth ─────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn excluded_capability_bypasses_auth() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey, NOT the client + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![allowed_keys.public_key().to_hex()], + excluded_capabilities: vec![CapabilityExclusion { + method: "tools/list".to_string(), + name: None, + }], + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Client's pubkey is NOT in the allowlist, but "tools/list" is excluded from auth. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("excl-1"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server should receive it because the method is in excluded_capabilities. + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout waiting for server to receive excluded-capability request") + .expect("server channel closed"); + + assert_eq!( + incoming.message.method(), + Some("tools/list"), + "excluded capability must bypass auth allowlist" + ); +} + +// ── 19. Publish resources produces correct kind ───────────────────────────── + +#[tokio::test] +async fn publish_resources_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let resources = vec![serde_json::json!({ + "uri": "file:///readme.md", + "name": "readme", + "mimeType": "text/markdown" + })]; + server + .publish_resources(resources) + .await + .expect("publish resources"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(RESOURCES_LIST_KIND)) + .expect("kind 11318 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["resources"] + .as_array() + .expect("resources must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "readme"); +} + +// ── 20. Publish prompts produces correct kind ─────────────────────────────── + +#[tokio::test] +async fn publish_prompts_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let prompts = vec![serde_json::json!({ + "name": "summarize", + "description": "Summarize text" + })]; + server + .publish_prompts(prompts) + .await + .expect("publish prompts"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(PROMPTS_LIST_KIND)) + .expect("kind 11320 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["prompts"] + .as_array() + .expect("prompts must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "summarize"); +} + +// ── 21. Publish resource templates produces correct kind ──────────────────── + +#[tokio::test] +async fn publish_resource_templates_produces_correct_kind() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + + let templates = vec![serde_json::json!({ + "uriTemplate": "file:///{path}", + "name": "file", + "mimeType": "application/octet-stream" + })]; + server + .publish_resource_templates(templates) + .await + .expect("publish resource templates"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(RESOURCETEMPLATES_LIST_KIND)) + .expect("kind 11319 event must be published"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["resourceTemplates"] + .as_array() + .expect("resourceTemplates must be an array"); + assert_eq!(arr.len(), 1); + assert_eq!(arr[0]["name"], "file"); +} + +// ── 22. Publish tools with empty list ─────────────────────────────────────── + +#[tokio::test] +async fn publish_tools_empty_list() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server + .publish_tools(vec![]) + .await + .expect("publish empty tools"); + + let events = pool.stored_events().await; + let event = events + .iter() + .find(|e| e.kind == Kind::Custom(TOOLS_LIST_KIND)) + .expect("kind 11317 event must be published for empty list"); + + let content: serde_json::Value = + serde_json::from_str(&event.content).expect("content must be JSON"); + let arr = content["tools"].as_array().expect("tools must be an array"); + assert!(arr.is_empty(), "empty tools list must produce tools: []"); +} + +// ── 23. Delete announcements k tags match kinds ───────────────────────────── + +#[tokio::test] +async fn delete_announcements_k_tags_match_kinds() { + let pool = Arc::new(MockRelayPool::new()); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + is_announced_server: true, + server_info: Some(ServerInfo { + name: Some("KTag-Server".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .expect("create server transport"); + + server.start().await.expect("server start"); + server.announce().await.expect("server announce"); + server + .delete_announcements("shutting down") + .await + .expect("delete announcements"); + + let events = pool.stored_events().await; + let kind5_events: Vec<_> = events + .iter() + .filter(|e| e.kind == Kind::Custom(5)) + .collect(); + + assert_eq!(kind5_events.len(), 5); + + // Collect k tag values from all kind-5 events. + let mut k_values: Vec = kind5_events + .iter() + .filter_map(|e| { + contextvm_sdk::core::serializers::get_tag_value(&e.tags, "k") + .and_then(|v| v.parse::().ok()) + }) + .collect(); + k_values.sort(); + + let mut expected = vec![ + SERVER_ANNOUNCEMENT_KIND, + TOOLS_LIST_KIND, + RESOURCES_LIST_KIND, + RESOURCETEMPLATES_LIST_KIND, + PROMPTS_LIST_KIND, + ]; + expected.sort(); + + assert_eq!( + k_values, expected, + "each kind-5 event must have a k tag matching one announcement kind" + ); +} + +// ── 24. Encryption Disabled server rejects gift-wrap ──────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn encryption_disabled_server_rejects_gift_wrap() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Server has encryption disabled — must reject gift-wrap events. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + // Client requires encryption — sends gift-wrap (kind 1059). + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("gw-reject"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let result = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()).await; + assert!( + result.is_err(), + "Disabled-mode server must drop gift-wrap events" + ); +} + +// ── 25. Response mirrors client encryption format ─────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn response_mirrors_client_encryption_format() { + // Part A: Disabled client → Optional server → response must be plaintext (kind 25910). + { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-plain"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send plaintext request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + assert!(!incoming.is_encrypted); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-plain"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send plaintext response"); + + // Client receives the response. + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Verify response event is plaintext kind 25910, not gift-wrap. + let events = server_pool.stored_events().await; + let response_events: Vec<_> = events + .iter() + .filter(|e| e.pubkey == server_pubkey && e.content.contains("mirror-plain")) + .collect(); + assert!( + !response_events.is_empty(), + "server must publish a response event" + ); + assert!( + response_events + .iter() + .all(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)), + "response to plaintext client must be kind {} (plaintext)", + CTXVM_MESSAGES_KIND + ); + } + + // Part B: Required client → Optional server → response must be gift-wrap (kind 1059). + { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + ..Default::default() + }, + Arc::clone(&server_pool) as Arc, + ) + .await + .expect("create server transport"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Required, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + + server.start().await.expect("server start"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-enc"), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send encrypted request"); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + assert!(incoming.is_encrypted); + + // Snapshot gift-wrap count before server responds. + let gw_before = server_pool + .stored_events() + .await + .iter() + .filter(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .count(); + + let response = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("mirror-enc"), + result: serde_json::json!({ "tools": [] }), + }); + server + .send_response(&incoming.event_id, response) + .await + .expect("send encrypted response"); + + // Client receives the response. + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + + // Verify server published exactly one new gift-wrap for the response. + let gw_after = server_pool + .stored_events() + .await + .iter() + .filter(|e| e.kind == Kind::Custom(GIFT_WRAP_KIND)) + .count(); + assert_eq!( + gw_after, + gw_before + 1, + "server must publish one new gift-wrap (kind {}) as the response", + GIFT_WRAP_KIND + ); + } +} From 96963db87b3f618406921c8a870d3047759c7c49 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Sun, 26 Apr 2026 11:48:45 +0530 Subject: [PATCH 47/69] fix(lru): unbounded LRU cache initialisation --- src/transport/client/correlation_store.rs | 20 +++++++++++++--- src/transport/server/correlation_store.rs | 28 +++++++++++++++++------ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs index 924e283..858b37a 100644 --- a/src/transport/client/correlation_store.rs +++ b/src/transport/client/correlation_store.rs @@ -6,6 +6,8 @@ use std::sync::Arc; use lru::LruCache; use tokio::sync::RwLock; +use crate::core::constants::DEFAULT_LRU_SIZE; + /// A pending request tracked by the correlation store. #[derive(Debug, Clone)] pub struct PendingRequest { @@ -32,9 +34,7 @@ impl Default for ClientCorrelationStore { impl ClientCorrelationStore { pub fn new() -> Self { - Self { - pending_requests: Arc::new(RwLock::new(LruCache::unbounded())), - } + Self::with_max_pending(DEFAULT_LRU_SIZE) } /// Create a store with an upper bound on pending requests. @@ -136,4 +136,18 @@ mod tests { assert!(store.remove("e1").await); assert!(!store.contains("e1").await); } + + #[tokio::test] + async fn default_store_is_bounded() { + let store = ClientCorrelationStore::new(); + for i in 0..=DEFAULT_LRU_SIZE { + store + .register(format!("e{i}"), serde_json::Value::Null, false) + .await; + } + + assert_eq!(store.count().await, DEFAULT_LRU_SIZE); + assert!(!store.contains("e0").await); + assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await); + } } diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index 09fd16e..3d8509e 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -7,6 +7,8 @@ use std::sync::Arc; use lru::LruCache; use tokio::sync::RwLock; +use crate::core::constants::DEFAULT_LRU_SIZE; + /// A route entry for an in-flight request. #[derive(Debug, Clone)] pub struct RouteEntry { @@ -29,11 +31,9 @@ struct Inner { } impl Inner { - fn new(max_routes: Option) -> Self { - let routes = match max_routes { - Some(n) => LruCache::new(NonZeroUsize::new(n).expect("max_routes must be non-zero")), - None => LruCache::unbounded(), - }; + fn new(max_routes: usize) -> Self { + let routes = + LruCache::new(NonZeroUsize::new(max_routes).expect("max_routes must be non-zero")); Self { routes, progress_token_to_event: HashMap::new(), @@ -80,7 +80,7 @@ impl Default for ServerEventRouteStore { impl ServerEventRouteStore { pub fn new() -> Self { Self { - inner: Arc::new(RwLock::new(Inner::new(None))), + inner: Arc::new(RwLock::new(Inner::new(DEFAULT_LRU_SIZE))), } } @@ -88,7 +88,7 @@ impl ServerEventRouteStore { /// When the limit is reached the oldest entry is evicted. pub fn with_max_routes(max_routes: usize) -> Self { Self { - inner: Arc::new(RwLock::new(Inner::new(Some(max_routes)))), + inner: Arc::new(RwLock::new(Inner::new(max_routes))), } } @@ -303,4 +303,18 @@ mod tests { assert!(store.get("e1").await.is_none()); assert!(store.get("e2").await.is_none()); } + + #[tokio::test] + async fn default_store_is_bounded() { + let store = ServerEventRouteStore::new(); + for i in 0..=DEFAULT_LRU_SIZE { + store + .register(format!("e{i}"), "pk1".into(), json!(i), None) + .await; + } + + assert_eq!(store.event_route_count().await, DEFAULT_LRU_SIZE); + assert!(!store.has_event_route("e0").await); + assert!(store.has_event_route(&format!("e{DEFAULT_LRU_SIZE}")).await); + } } From 350d5696bcf042db1c8f5e06bc66ae03625280fa Mon Sep 17 00:00:00 2001 From: ContextVM-org Date: Tue, 28 Apr 2026 16:47:56 +0100 Subject: [PATCH 48/69] test: add GiftWrapMode import to stateless mode conformance test --- tests/conformance_stateless_mode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index e1c0310..7cb420e 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -6,7 +6,7 @@ use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; use contextvm_sdk::core::types::{ EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, }; -use contextvm_sdk::signer; +use contextvm_sdk::{GiftWrapMode, signer}; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use tokio::time::timeout; From 98d9fd4ec03001bae67b85b8ee868138c14036c7 Mon Sep 17 00:00:00 2001 From: ContextVM Date: Tue, 28 Apr 2026 17:32:52 +0100 Subject: [PATCH 49/69] style: reorder import statements alphabetically --- tests/conformance_stateless_mode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 7cb420e..9999e23 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -6,8 +6,8 @@ use contextvm_sdk::core::constants::{mcp_protocol_version, INITIALIZE_METHOD}; use contextvm_sdk::core::types::{ EncryptionMode, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, }; -use contextvm_sdk::{GiftWrapMode, signer}; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::{signer, GiftWrapMode}; use tokio::time::timeout; async fn make_stateless_transport() -> ( From 7413e108f5faff11da4f884597dcea9b421a985c Mon Sep 17 00:00:00 2001 From: ContextVM Date: Wed, 29 Apr 2026 12:28:06 +0100 Subject: [PATCH 50/69] format --- sdk | 1 + tests/transport_integration.rs | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 160000 sdk diff --git a/sdk b/sdk new file mode 160000 index 0000000..5f773a2 --- /dev/null +++ b/sdk @@ -0,0 +1 @@ +Subproject commit 5f773a20d9ea4b0a5f06d1b860d3d3da7509699f diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index 8a09c2e..01ff24f 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -2006,7 +2006,7 @@ async fn send_response_publish_failure_allows_one_successful_retry() { Arc::clone(&server_pool), 1, )); - + let mut server = NostrServerTransport::with_relay_pool( NostrServerTransportConfig { encryption_mode: EncryptionMode::Disabled, From be992f0cf046c656937e32b45dd40152996ed13d Mon Sep 17 00:00:00 2001 From: Harsh Date: Wed, 29 Apr 2026 18:31:35 +0530 Subject: [PATCH 51/69] fix: send JSON-RPC -32000 Unauthorized error on announced servers with allowlist --- src/transport/server/mod.rs | 54 +++++++++ tests/transport_integration.rs | 210 +++++++++++++++++++++++++++++++++ 2 files changed, 264 insertions(+) diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 5ae383e..ff07ef2 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -224,6 +224,7 @@ impl NostrServerTransport { let allowed = self.config.allowed_public_keys.clone(); let excluded = self.config.excluded_capabilities.clone(); let encryption_mode = self.config.encryption_mode; + let is_announced_server = self.config.is_announced_server; let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); tokio::spawn(async move { @@ -235,6 +236,7 @@ impl NostrServerTransport { allowed, excluded, encryption_mode, + is_announced_server, seen_gift_wrap_ids, ) .await; @@ -656,6 +658,7 @@ impl NostrServerTransport { allowed_pubkeys: Vec, excluded_capabilities: Vec, encryption_mode: EncryptionMode, + is_announced_server: bool, seen_gift_wrap_ids: Arc>>, ) { let mut notifications = relay_pool.notifications(); @@ -799,6 +802,57 @@ impl NostrServerTransport { method = method, "Unauthorized request" ); + + // On announced servers, send a JSON-RPC error back for + // Request messages so the client doesn't hang indefinitely. + if is_announced_server { + if let JsonRpcMessage::Request(ref req) = mcp_msg { + let error_response = + JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: req.id.clone(), + error: JsonRpcError { + code: -32000, + message: "Unauthorized".to_string(), + data: None, + }, + }); + + if let Ok(client_pk) = PublicKey::from_hex(&sender_pubkey) { + let event_id_parsed = EventId::from_hex(&event_id) + .unwrap_or(EventId::all_zeros()); + let tags = BaseTransport::create_response_tags( + &client_pk, + &event_id_parsed, + ); + + let base = BaseTransport { + relay_pool: Arc::clone(&relay_pool), + encryption_mode, + is_connected: true, + }; + if let Err(e) = base + .send_mcp_message( + &error_response, + &client_pk, + CTXVM_MESSAGES_KIND, + tags, + Some(is_encrypted), + None, + ) + .await + { + tracing::error!( + target: LOG_TARGET, + error = %e, + sender_pubkey = %sender_pubkey, + "Failed to send unauthorized error response" + ); + } + } + } + } + continue; } } diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index 01ff24f..e4b0e80 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -2119,3 +2119,213 @@ async fn send_response_publish_failure_allows_one_successful_retry() { "client must receive the retried response exactly once" ); } + +// ── 28. Announced server sends unauthorized error response ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn announced_server_sends_unauthorized_error_response() { + let allowed_keys = Keys::generate(); // a DIFFERENT pubkey — client is NOT in the allowlist + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Announced server with an allowlist that does NOT include the client. + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![allowed_keys.public_key().to_hex()], + is_announced_server: true, + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a non-initialize request from the unauthorized client. + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(42), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // The server handler must NOT receive the request (it's unauthorized). + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized request must not reach the server handler" + ); + + // The client MUST receive a -32000 Unauthorized error response. + let error_msg = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .expect("timeout waiting for unauthorized error response") + .expect("client channel closed"); + + match error_msg { + JsonRpcMessage::ErrorResponse(err) => { + assert_eq!(err.error.code, -32000, "error code must be -32000"); + assert_eq!( + err.error.message, "Unauthorized", + "error message must be 'Unauthorized'" + ); + } + other => panic!( + "expected ErrorResponse, got: {:?}", + std::mem::discriminant(&other) + ), + } +} + +// ── 29. Private server silently drops unauthorized request ─────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn private_server_silently_drops_unauthorized_request() { + let allowed_keys = Keys::generate(); + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + // Private server (is_announced_server defaults to false). + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![allowed_keys.public_key().to_hex()], + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + let request = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(99), + method: "tools/list".to_string(), + params: None, + }); + client.send(&request).await.expect("send request"); + + // Server handler must not receive it. + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized request must not reach the server handler" + ); + + // Client must NOT receive any error response (private server silently drops). + let client_response = tokio::time::timeout(Duration::from_millis(300), client_rx.recv()).await; + assert!( + client_response.is_err(), + "private server must silently drop unauthorized requests without sending an error" + ); +} + +// ── 30. Announced server does not error on unauthorized notification ───────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn announced_server_does_not_error_on_unauthorized_notification() { + let allowed_keys = Keys::generate(); + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + allowed_public_keys: vec![allowed_keys.public_key().to_hex()], + is_announced_server: true, + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + server.start().await.expect("server start"); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut client_rx = client + .take_message_receiver() + .expect("client message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Send a notification (not a request) from the unauthorized client. + let notification = JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: None, + }); + client.send(¬ification).await.expect("send notification"); + + // Server handler must not receive the notification. + let server_forward = tokio::time::timeout(Duration::from_millis(300), server_rx.recv()).await; + assert!( + server_forward.is_err(), + "unauthorized notification must not reach the server handler" + ); + + // Client must NOT receive an error (notifications never get error replies). + let client_response = tokio::time::timeout(Duration::from_millis(300), client_rx.recv()).await; + assert!( + client_response.is_err(), + "announced server must not send error response for unauthorized notifications" + ); +} From 4280b1e89276d3cd5bb3a7a9ef595675e2e9d63d Mon Sep 17 00:00:00 2001 From: Harsh Date: Sun, 26 Apr 2026 12:54:47 +0530 Subject: [PATCH 52/69] feat: add discovery tags module, PeerCapabilities, and server config foundation for CEP-35 --- src/core/constants.rs | 7 + src/gateway/mod.rs | 1 + src/lib.rs | 1 + src/transport/discovery_tags.rs | 357 ++++++++++++++++++++++ src/transport/mod.rs | 2 + src/transport/server/correlation_store.rs | 4 + src/transport/server/mod.rs | 3 + 7 files changed, 375 insertions(+) create mode 100644 src/transport/discovery_tags.rs diff --git a/src/core/constants.rs b/src/core/constants.rs index f610637..dd8cc08 100644 --- a/src/core/constants.rs +++ b/src/core/constants.rs @@ -64,6 +64,9 @@ pub mod tags { /// Support ephemeral gift wrap kind (21059) for encrypted messages (CEP-19) pub const SUPPORT_ENCRYPTION_EPHEMERAL: &str = "support_encryption_ephemeral"; + + /// Support CEP-22 oversized payload transfer via notifications/progress framing + pub const SUPPORT_OVERSIZED_TRANSFER: &str = "support_oversized_transfer"; } /// Maximum message size (1MB) @@ -162,6 +165,10 @@ mod tests { tags::SUPPORT_ENCRYPTION_EPHEMERAL, "support_encryption_ephemeral" ); + assert_eq!( + tags::SUPPORT_OVERSIZED_TRANSFER, + "support_oversized_transfer" + ); } #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index ea388b7..e9d00f2 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -117,6 +117,7 @@ mod tests { let nostr_config = NostrServerTransportConfig { relay_urls: vec!["wss://relay.example.com".to_string()], encryption_mode: EncryptionMode::Required, + gift_wrap_mode: GiftWrapMode::Optional, server_info: Some(ServerInfo { name: Some("Test Gateway".to_string()), version: Some("1.0.0".to_string()), diff --git a/src/lib.rs b/src/lib.rs index 828439f..575e356 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,7 @@ pub use relay::{RelayPool, RelayPoolTrait}; pub use transport::client::{ ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig, }; +pub use transport::discovery_tags::{DiscoveredPeerCapabilities, PeerCapabilities}; pub use transport::server::{ IncomingRequest, NostrServerTransport, NostrServerTransportConfig, RouteEntry, ServerEventRouteStore, SessionSnapshot, SessionStore, diff --git a/src/transport/discovery_tags.rs b/src/transport/discovery_tags.rs new file mode 100644 index 0000000..20df62c --- /dev/null +++ b/src/transport/discovery_tags.rs @@ -0,0 +1,357 @@ +//! Discovery tag utilities for CEP-35 capability exchange. +//! +//! Ports the TS SDK's `discovery-tags.ts` module. Provides functions to filter, +//! parse, learn, and merge discovery tags on Nostr events exchanged between +//! MCP clients and servers. + +use std::collections::HashSet; + +use nostr_sdk::prelude::*; + +use crate::core::constants::tags; + +/// Routing tag names that are excluded from discovery tags. +const NON_DISCOVERY_TAG_NAMES: &[&str] = &["p", "e"]; + +/// Capability flags learned from inbound peer discovery tags. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct PeerCapabilities { + /// Peer supports NIP-44/NIP-59 encrypted messaging. + pub supports_encryption: bool, + /// Peer supports ephemeral gift wraps (kind 21059, CEP-19). + pub supports_ephemeral_encryption: bool, + /// Peer supports CEP-22 oversized payload transfer. + pub supports_oversized_transfer: bool, +} + +/// Returns `true` when the tag list contains a single-valued tag whose name matches `name`. +/// +/// A single-valued tag is a tag array whose only element is the tag name itself, +/// e.g. `["support_encryption"]`. +pub fn has_single_tag(tags: &[Tag], name: &str) -> bool { + tags.iter().any(|tag| { + let v = tag.clone().to_vec(); + v.len() == 1 && v[0] == name + }) +} + +/// Filters out routing tags (`p`, `e`) and returns cloned discovery tags. +/// +/// Mirrors TS SDK `getDiscoveryTags()`. +pub fn get_discovery_tags(tags: &[Tag]) -> Vec { + tags.iter() + .filter(|tag| { + let v = (*tag).clone().to_vec(); + match v.first() { + Some(name) => !NON_DISCOVERY_TAG_NAMES.contains(&name.as_str()), + None => false, + } + }) + .cloned() + .collect() +} + +/// Inspects tags and returns discovered peer capabilities. +/// +/// Mirrors TS SDK `learnPeerCapabilities()`. +pub fn learn_peer_capabilities(tags: &[Tag]) -> PeerCapabilities { + PeerCapabilities { + supports_encryption: has_single_tag(tags, tags::SUPPORT_ENCRYPTION), + supports_ephemeral_encryption: has_single_tag(tags, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + supports_oversized_transfer: has_single_tag(tags, tags::SUPPORT_OVERSIZED_TRANSFER), + } +} + +/// Parsed capability flags together with the raw discovery tags. +#[derive(Debug, Clone)] +pub struct DiscoveredPeerCapabilities { + /// The filtered discovery tags (routing tags stripped). + pub discovery_tags: Vec, + /// Parsed capability flags. + pub capabilities: PeerCapabilities, +} + +/// Parses peer discovery tags into normalized capability flags plus the raw +/// discovery tags for storage/forwarding. +/// +/// Mirrors TS SDK `parseDiscoveredPeerCapabilities()`. +pub fn parse_discovered_peer_capabilities(tags: &[Tag]) -> DiscoveredPeerCapabilities { + let discovery_tags = get_discovery_tags(tags); + let capabilities = learn_peer_capabilities(&discovery_tags); + DiscoveredPeerCapabilities { + discovery_tags, + capabilities, + } +} + +/// Merges incoming discovery tags into the current set, preserving order and +/// deduplicating by full tag content (all elements must match). +/// +/// Mirrors TS SDK `mergeDiscoveryTags()`. +pub fn merge_discovery_tags(current: &[Tag], incoming: &[Tag]) -> Vec { + let mut merged: Vec = current.to_vec(); + let mut seen: HashSet> = merged.iter().map(|t| t.clone().to_vec()).collect(); + + for tag in incoming { + let key = tag.clone().to_vec(); + if seen.insert(key) { + merged.push(tag.clone()); + } + } + + merged +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_tag(parts: &[&str]) -> Tag { + let kind = TagKind::Custom(parts[0].into()); + let values: Vec = parts[1..].iter().map(|s| s.to_string()).collect(); + Tag::custom(kind, values) + } + + fn tag_name(tag: &Tag) -> String { + tag.clone().to_vec()[0].clone() + } + + // ── has_single_tag ────────────────────────────────────────────── + + #[test] + fn has_single_tag_finds_present() { + let tags = vec![make_tag(&["support_encryption"])]; + assert!(has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_ignores_multi_value() { + let tags = vec![make_tag(&["support_encryption", "extra"])]; + assert!(!has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_returns_false_when_absent() { + let tags = vec![make_tag(&["other_tag"])]; + assert!(!has_single_tag(&tags, "support_encryption")); + } + + #[test] + fn has_single_tag_empty_tags() { + assert!(!has_single_tag(&[], "support_encryption")); + } + + // ── get_discovery_tags ────────────────────────────────────────── + + #[test] + fn get_discovery_tags_filters_routing_tags() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + make_tag(&["support_encryption"]), + make_tag(&["name", "My Server"]), + ]; + let discovery = get_discovery_tags(&tags); + assert_eq!(discovery.len(), 2); + assert_eq!(tag_name(&discovery[0]), "support_encryption"); + assert_eq!(tag_name(&discovery[1]), "name"); + } + + #[test] + fn get_discovery_tags_empty_input() { + let discovery = get_discovery_tags(&[]); + assert!(discovery.is_empty()); + } + + #[test] + fn get_discovery_tags_all_routing() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + ]; + let discovery = get_discovery_tags(&tags); + assert!(discovery.is_empty()); + } + + #[test] + fn get_discovery_tags_preserves_order() { + let tags = vec![ + make_tag(&["about", "hello"]), + Tag::public_key(Keys::generate().public_key()), + make_tag(&["website", "https://example.com"]), + make_tag(&["support_encryption"]), + ]; + let discovery = get_discovery_tags(&tags); + assert_eq!(discovery.len(), 3); + assert_eq!(tag_name(&discovery[0]), "about"); + assert_eq!(tag_name(&discovery[1]), "website"); + assert_eq!(tag_name(&discovery[2]), "support_encryption"); + } + + // ── learn_peer_capabilities ───────────────────────────────────── + + #[test] + fn learn_peer_capabilities_all_present() { + let tags = vec![ + make_tag(&["support_encryption"]), + make_tag(&["support_encryption_ephemeral"]), + make_tag(&["support_oversized_transfer"]), + ]; + let caps = learn_peer_capabilities(&tags); + assert!(caps.supports_encryption); + assert!(caps.supports_ephemeral_encryption); + assert!(caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_none_present() { + let tags = vec![make_tag(&["name", "Server"])]; + let caps = learn_peer_capabilities(&tags); + assert!(!caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_partial() { + let tags = vec![make_tag(&["support_encryption"])]; + let caps = learn_peer_capabilities(&tags); + assert!(caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn learn_peer_capabilities_empty() { + let caps = learn_peer_capabilities(&[]); + assert_eq!(caps, PeerCapabilities::default()); + } + + #[test] + fn learn_peer_capabilities_ignores_multi_value_capability_tags() { + // Tags with values (e.g. ["support_encryption", "extra"]) are not + // single-valued and should not be treated as capability flags. + let tags = vec![ + make_tag(&["support_encryption", "yes"]), + make_tag(&["support_encryption_ephemeral"]), + ]; + let caps = learn_peer_capabilities(&tags); + assert!(!caps.supports_encryption); + assert!(caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + // ── parse_discovered_peer_capabilities ────────────────────────── + + #[test] + fn parse_discovered_peer_capabilities_filters_and_parses() { + let tags = vec![ + Tag::public_key(Keys::generate().public_key()), + Tag::event(EventId::all_zeros()), + make_tag(&["support_encryption"]), + make_tag(&["support_encryption_ephemeral"]), + make_tag(&["name", "Test Server"]), + ]; + let result = parse_discovered_peer_capabilities(&tags); + + // Routing tags filtered out + assert_eq!(result.discovery_tags.len(), 3); + + // Capabilities parsed correctly + assert!(result.capabilities.supports_encryption); + assert!(result.capabilities.supports_ephemeral_encryption); + assert!(!result.capabilities.supports_oversized_transfer); + } + + #[test] + fn parse_discovered_peer_capabilities_empty() { + let result = parse_discovered_peer_capabilities(&[]); + assert!(result.discovery_tags.is_empty()); + assert_eq!(result.capabilities, PeerCapabilities::default()); + } + + // ── merge_discovery_tags ──────────────────────────────────────── + + #[test] + fn merge_discovery_tags_appends_new() { + let current = vec![make_tag(&["support_encryption"])]; + let incoming = vec![make_tag(&["name", "Server"])]; + let merged = merge_discovery_tags(¤t, &incoming); + assert_eq!(merged.len(), 2); + assert_eq!(tag_name(&merged[0]), "support_encryption"); + assert_eq!(tag_name(&merged[1]), "name"); + } + + #[test] + fn merge_discovery_tags_deduplicates() { + let tag_a = make_tag(&["support_encryption"]); + let tag_b = make_tag(&["support_encryption"]); + let current = vec![tag_a]; + let incoming = vec![tag_b]; + let merged = merge_discovery_tags(¤t, &incoming); + assert_eq!(merged.len(), 1); + } + + #[test] + fn merge_discovery_tags_deduplicates_with_values() { + let current = vec![make_tag(&["name", "Server"])]; + let incoming = vec![ + make_tag(&["name", "Server"]), // exact dup + make_tag(&["name", "Other Server"]), // same tag name, different value — not a dup + ]; + let merged = merge_discovery_tags(¤t, &incoming); + assert_eq!(merged.len(), 2); + } + + #[test] + fn merge_discovery_tags_preserves_order() { + let current = vec![make_tag(&["b_tag"]), make_tag(&["a_tag"])]; + let incoming = vec![make_tag(&["c_tag"]), make_tag(&["a_tag"])]; + let merged = merge_discovery_tags(¤t, &incoming); + assert_eq!(merged.len(), 3); + assert_eq!(tag_name(&merged[0]), "b_tag"); + assert_eq!(tag_name(&merged[1]), "a_tag"); + assert_eq!(tag_name(&merged[2]), "c_tag"); + } + + #[test] + fn merge_discovery_tags_both_empty() { + let merged = merge_discovery_tags(&[], &[]); + assert!(merged.is_empty()); + } + + #[test] + fn merge_discovery_tags_current_empty() { + let incoming = vec![make_tag(&["support_encryption"])]; + let merged = merge_discovery_tags(&[], &incoming); + assert_eq!(merged.len(), 1); + } + + #[test] + fn merge_discovery_tags_incoming_empty() { + let current = vec![make_tag(&["support_encryption"])]; + let merged = merge_discovery_tags(¤t, &[]); + assert_eq!(merged.len(), 1); + } + + // ── PeerCapabilities ──────────────────────────────────────────── + + #[test] + fn peer_capabilities_default_all_false() { + let caps = PeerCapabilities::default(); + assert!(!caps.supports_encryption); + assert!(!caps.supports_ephemeral_encryption); + assert!(!caps.supports_oversized_transfer); + } + + #[test] + fn peer_capabilities_copy_semantics() { + let caps = PeerCapabilities { + supports_encryption: true, + supports_ephemeral_encryption: true, + supports_oversized_transfer: false, + }; + let copy = caps; + assert_eq!(caps, copy); + } +} diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 0a31f45..a5d53e8 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -5,7 +5,9 @@ pub mod base; pub mod client; +pub mod discovery_tags; pub mod server; pub use client::{ClientCorrelationStore, NostrClientTransport, NostrClientTransportConfig}; +pub use discovery_tags::*; pub use server::{NostrServerTransport, NostrServerTransportConfig, ServerEventRouteStore}; diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index 3d8509e..cad616e 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -18,6 +18,9 @@ pub struct RouteEntry { pub original_request_id: serde_json::Value, /// Optional progress token for this request. pub progress_token: Option, + /// The outer gift-wrap event kind that carried this request (e.g. 1059 or 21059). + /// Populated from the inbound event in a later PR; `None` until then. + pub wrap_kind: Option, } /// Internal state behind the lock. @@ -123,6 +126,7 @@ impl ServerEventRouteStore { client_pubkey, original_request_id, progress_token, + wrap_kind: None, }, ); diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index ff07ef2..fa09762 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -35,6 +35,8 @@ pub struct NostrServerTransportConfig { pub relay_urls: Vec, /// Encryption mode. pub encryption_mode: EncryptionMode, + /// Gift-wrap kind selection policy (CEP-19). + pub gift_wrap_mode: GiftWrapMode, /// Server information for announcements. pub server_info: Option, /// Whether this server publishes public announcements (CEP-6). @@ -56,6 +58,7 @@ impl Default for NostrServerTransportConfig { Self { relay_urls: vec!["wss://relay.damus.io".to_string()], encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, server_info: None, is_announced_server: false, allowed_public_keys: Vec::new(), From 1e495725869e5a8bc1886855279b068cad32b970 Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 30 Apr 2026 17:56:10 +0530 Subject: [PATCH 53/69] refactor: remove merge_discovery_tags per CEP-35 spec update --- src/transport/discovery_tags.rs | 88 +-------------------------------- 1 file changed, 2 insertions(+), 86 deletions(-) diff --git a/src/transport/discovery_tags.rs b/src/transport/discovery_tags.rs index 20df62c..ca20a9c 100644 --- a/src/transport/discovery_tags.rs +++ b/src/transport/discovery_tags.rs @@ -1,10 +1,8 @@ //! Discovery tag utilities for CEP-35 capability exchange. //! //! Ports the TS SDK's `discovery-tags.ts` module. Provides functions to filter, -//! parse, learn, and merge discovery tags on Nostr events exchanged between -//! MCP clients and servers. - -use std::collections::HashSet; +//! parse, and learn discovery tags on Nostr events exchanged between MCP clients +//! and servers. use nostr_sdk::prelude::*; @@ -84,24 +82,6 @@ pub fn parse_discovered_peer_capabilities(tags: &[Tag]) -> DiscoveredPeerCapabil } } -/// Merges incoming discovery tags into the current set, preserving order and -/// deduplicating by full tag content (all elements must match). -/// -/// Mirrors TS SDK `mergeDiscoveryTags()`. -pub fn merge_discovery_tags(current: &[Tag], incoming: &[Tag]) -> Vec { - let mut merged: Vec = current.to_vec(); - let mut seen: HashSet> = merged.iter().map(|t| t.clone().to_vec()).collect(); - - for tag in incoming { - let key = tag.clone().to_vec(); - if seen.insert(key) { - merged.push(tag.clone()); - } - } - - merged -} - #[cfg(test)] mod tests { use super::*; @@ -270,70 +250,6 @@ mod tests { assert_eq!(result.capabilities, PeerCapabilities::default()); } - // ── merge_discovery_tags ──────────────────────────────────────── - - #[test] - fn merge_discovery_tags_appends_new() { - let current = vec![make_tag(&["support_encryption"])]; - let incoming = vec![make_tag(&["name", "Server"])]; - let merged = merge_discovery_tags(¤t, &incoming); - assert_eq!(merged.len(), 2); - assert_eq!(tag_name(&merged[0]), "support_encryption"); - assert_eq!(tag_name(&merged[1]), "name"); - } - - #[test] - fn merge_discovery_tags_deduplicates() { - let tag_a = make_tag(&["support_encryption"]); - let tag_b = make_tag(&["support_encryption"]); - let current = vec![tag_a]; - let incoming = vec![tag_b]; - let merged = merge_discovery_tags(¤t, &incoming); - assert_eq!(merged.len(), 1); - } - - #[test] - fn merge_discovery_tags_deduplicates_with_values() { - let current = vec![make_tag(&["name", "Server"])]; - let incoming = vec![ - make_tag(&["name", "Server"]), // exact dup - make_tag(&["name", "Other Server"]), // same tag name, different value — not a dup - ]; - let merged = merge_discovery_tags(¤t, &incoming); - assert_eq!(merged.len(), 2); - } - - #[test] - fn merge_discovery_tags_preserves_order() { - let current = vec![make_tag(&["b_tag"]), make_tag(&["a_tag"])]; - let incoming = vec![make_tag(&["c_tag"]), make_tag(&["a_tag"])]; - let merged = merge_discovery_tags(¤t, &incoming); - assert_eq!(merged.len(), 3); - assert_eq!(tag_name(&merged[0]), "b_tag"); - assert_eq!(tag_name(&merged[1]), "a_tag"); - assert_eq!(tag_name(&merged[2]), "c_tag"); - } - - #[test] - fn merge_discovery_tags_both_empty() { - let merged = merge_discovery_tags(&[], &[]); - assert!(merged.is_empty()); - } - - #[test] - fn merge_discovery_tags_current_empty() { - let incoming = vec![make_tag(&["support_encryption"])]; - let merged = merge_discovery_tags(&[], &incoming); - assert_eq!(merged.len(), 1); - } - - #[test] - fn merge_discovery_tags_incoming_empty() { - let current = vec![make_tag(&["support_encryption"])]; - let merged = merge_discovery_tags(¤t, &[]); - assert_eq!(merged.len(), 1); - } - // ── PeerCapabilities ──────────────────────────────────────────── #[test] From f0ea0976746b08fd1e188a3f2041dde0c5a73cd2 Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 30 Apr 2026 20:12:05 +0530 Subject: [PATCH 54/69] fix: register client request before publish and drop uncorrelated responses --- src/transport/base.rs | 52 +++++++++++++++++++++++++++++++++++++ src/transport/client/mod.rs | 41 ++++++++++++++++++++++++++--- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/src/transport/base.rs b/src/transport/base.rs index ccd7e03..781d58b 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -101,6 +101,58 @@ impl BaseTransport { self.relay_pool.sign(builder).await } + /// Prepare an MCP message for publishing without actually publishing it. + /// + /// Signs (and optionally gift-wraps) the event, returning the inner signed + /// event ID together with the final event that should be published to relays. + pub async fn prepare_mcp_message( + &self, + message: &JsonRpcMessage, + recipient: &PublicKey, + kind: u16, + tags: Vec, + is_encrypted: Option, + gift_wrap_kind: Option, + ) -> Result<(EventId, Event)> { + let should_encrypt = self.should_encrypt(kind, is_encrypted); + + let event = self.create_signed_event(message, kind, tags).await?; + let signed_event_id = event.id; + + if should_encrypt { + let event_json = + serde_json::to_string(&event).map_err(|e| Error::Encryption(e.to_string()))?; + let signer = self + .relay_pool + .signer() + .await + .map_err(|e| Error::Encryption(e.to_string()))?; + let selected_gift_wrap_kind = gift_wrap_kind.unwrap_or(GIFT_WRAP_KIND); + let gift_wrap_event = encryption::gift_wrap_single_layer_with_kind( + &signer, + recipient, + &event_json, + selected_gift_wrap_kind, + ) + .await?; + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + envelope_id = %gift_wrap_event.id, + gift_wrap_kind = selected_gift_wrap_kind, + "Prepared encrypted MCP message" + ); + Ok((signed_event_id, gift_wrap_event)) + } else { + tracing::debug!( + target: LOG_TARGET, + signed_event_id = %signed_event_id, + "Prepared unencrypted MCP message" + ); + Ok((signed_event_id, event)) + } + } + /// Send an MCP message to a recipient, optionally encrypting. /// /// Returns the signed MCP event ID. diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index c3646f9..675be02 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -274,9 +274,10 @@ impl NostrClientTransport { } let tags = BaseTransport::create_recipient_tags(&self.server_pubkey); - let event_id = self + + let (event_id, publishable_event) = self .base - .send_mcp_message( + .prepare_mcp_message( message, &self.server_pubkey, CTXVM_MESSAGES_KIND, @@ -291,7 +292,7 @@ impl NostrClientTransport { error = %error, server_pubkey = %self.server_pubkey.to_hex(), method = ?message.method(), - "Failed to send client message" + "Failed to prepare client message" ); error })?; @@ -303,6 +304,18 @@ impl NostrClientTransport { .await; } + if let Err(error) = self.base.relay_pool.publish_event(&publishable_event).await { + self.pending_requests.remove(&event_id.to_hex()).await; + tracing::error!( + target: LOG_TARGET, + error = %error, + server_pubkey = %self.server_pubkey.to_hex(), + method = ?message.method(), + "Failed to publish client message" + ); + return Err(error); + } + tracing::debug!( target: LOG_TARGET, event_id = %event_id.to_hex(), @@ -502,6 +515,28 @@ impl NostrClientTransport { // Parse MCP message if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { + // Drop uncorrelated responses and server-to-client requests (matches TS SDK). + match &mcp_msg { + JsonRpcMessage::Response(_) | JsonRpcMessage::ErrorResponse(_) + if e_tag.is_none() => + { + tracing::warn!( + target: LOG_TARGET, + "Dropping response/error without correlation `e` tag" + ); + continue; + } + JsonRpcMessage::Request(_) => { + tracing::warn!( + target: LOG_TARGET, + method = ?mcp_msg.method(), + "Dropping server-to-client request (invalid in MCP)" + ); + continue; + } + _ => {} + } + // Clean up pending request if let Some(ref correlated_id) = e_tag { pending.remove(correlated_id.as_str()).await; From c499b8d6b5914f642c70939e1a0788fd1648b393 Mon Sep 17 00:00:00 2001 From: Kushagra Date: Sat, 2 May 2026 10:07:46 +0530 Subject: [PATCH 55/69] feat(cep19): added missing test cases --- src/transport/server/mod.rs | 30 ++++++++++++++----- tests/transport_integration.rs | 54 +++++++++++++++++++++++++--------- 2 files changed, 62 insertions(+), 22 deletions(-) diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index c56a8df..5a4e68f 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -943,7 +943,7 @@ impl NostrServerTransport { let has_sent = sessions .get_session(&sender_pubkey) .await - .map_or(false, |s| s.has_sent_common_tags); + .is_some_and(|s| s.has_sent_common_tags); if !has_sent { Self::append_common_response_tags( &mut tags, @@ -1546,7 +1546,8 @@ mod tests { } #[test] - fn test_select_outbound_notification_gift_wrap_kind_falls_back_to_mode_if_correlated_not_allowed() { + fn test_select_outbound_notification_gift_wrap_kind_falls_back_to_mode_if_correlated_not_allowed( + ) { assert_eq!( NostrServerTransport::select_outbound_notification_gift_wrap_kind( GiftWrapMode::Ephemeral, @@ -1572,7 +1573,8 @@ mod tests { } #[test] - fn test_select_outbound_notification_gift_wrap_kind_uses_persistent_if_ephemeral_supported_but_mode_persistent() { + fn test_select_outbound_notification_gift_wrap_kind_uses_persistent_if_ephemeral_supported_but_mode_persistent( + ) { assert_eq!( NostrServerTransport::select_outbound_notification_gift_wrap_kind( GiftWrapMode::Persistent, @@ -1585,7 +1587,8 @@ mod tests { } #[test] - fn test_select_outbound_notification_gift_wrap_kind_uses_default_mode_if_ephemeral_not_supported() { + fn test_select_outbound_notification_gift_wrap_kind_uses_default_mode_if_ephemeral_not_supported( + ) { assert_eq!( NostrServerTransport::select_outbound_notification_gift_wrap_kind( GiftWrapMode::Optional, @@ -1609,7 +1612,9 @@ mod tests { ); let kinds: Vec = tags.iter().map(|t| format!("{:?}", t.kind())).collect(); assert!( - kinds.iter().any(|k| k.contains("support_encryption_ephemeral")), + kinds + .iter() + .any(|k| k.contains("support_encryption_ephemeral")), "should include support_encryption_ephemeral tag" ); } @@ -1628,14 +1633,20 @@ mod tests { EncryptionMode::Disabled, GiftWrapMode::Optional, ); - let tag_value = crate::core::serializers::get_tag_value(&tags, "name"); + let tag_value = tags + .iter() + .find(|t| (*t).clone().to_vec().first().map(|s| s.as_str()) == Some("name")) + .and_then(|t| t.clone().to_vec().get(1).cloned()); assert_eq!(tag_value.as_deref(), Some("TestServer")); } #[test] fn test_append_common_response_tags_extra_tags() { let mut tags = Vec::new(); - let extra_tags = vec![Tag::custom(TagKind::Custom("custom_tag".into()), vec!["value".to_string()])]; + let extra_tags = vec![Tag::custom( + TagKind::Custom("custom_tag".into()), + vec!["value".to_string()], + )]; NostrServerTransport::append_common_response_tags( &mut tags, None, @@ -1643,7 +1654,10 @@ mod tests { EncryptionMode::Disabled, GiftWrapMode::Optional, ); - let tag_value = crate::core::serializers::get_tag_value(&tags, "custom_tag"); + let tag_value = tags + .iter() + .find(|t| (*t).clone().to_vec().first().map(|s| s.as_str()) == Some("custom_tag")) + .and_then(|t| t.clone().to_vec().get(1).cloned()); assert_eq!(tag_value.as_deref(), Some("value")); } } diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index ac54d0f..a51ee8b 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -2406,10 +2406,15 @@ async fn first_response_includes_discovery_tags() { }); client.send(&request2).await.expect("send request 2"); - let incoming2 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) - .await - .expect("timeout") - .expect("channel closed"); + let incoming2 = loop { + let msg = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout") + .expect("channel closed"); + if msg.message.is_request() && msg.message.id() == Some(&serde_json::json!("req-2")) { + break msg; + } + }; let response2 = JsonRpcMessage::Response(JsonRpcResponse { jsonrpc: "2.0".to_string(), @@ -2422,20 +2427,32 @@ async fn first_response_includes_discovery_tags() { .expect("send response 2"); let events = s_pool.stored_events().await; - let mut responses: Vec<_> = events + let responses: Vec<_> = events .iter() .filter(|e| e.kind == Kind::Custom(contextvm_sdk::core::constants::CTXVM_MESSAGES_KIND)) .cloned() .collect(); - - let resp1 = responses.iter().find(|e| e.content.contains("req-1")).expect("resp1 missing"); - let resp2 = responses.iter().find(|e| e.content.contains("req-2")).expect("resp2 missing"); + + let resp1 = responses + .iter() + .find(|e| e.content.contains("req-1") && e.content.contains("result")) + .expect("resp1 missing"); + let resp2 = responses + .iter() + .find(|e| e.content.contains("req-2") && e.content.contains("result")) + .expect("resp2 missing"); let name1 = contextvm_sdk::core::serializers::get_tag_value(&resp1.tags, "name"); - let enc1 = resp1.tags.iter().any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + let enc1 = resp1 + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); let name2 = contextvm_sdk::core::serializers::get_tag_value(&resp2.tags, "name"); - let enc2 = resp2.tags.iter().any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); + let enc2 = resp2 + .tags + .iter() + .any(|t| t.clone().to_vec().first().map(|s| s.as_str()) == Some("support_encryption")); assert_eq!(name1.as_deref(), Some("Disco-Server")); assert!(enc1); @@ -2504,7 +2521,11 @@ async fn notification_mirror_selection_wrt_cep_19() { params: None, }); server - .send_notification(&incoming1.client_pubkey, ¬ification, Some(&incoming1.event_id)) + .send_notification( + &incoming1.client_pubkey, + ¬ification, + Some(&incoming1.event_id), + ) .await .expect("send notification"); @@ -2512,10 +2533,15 @@ async fn notification_mirror_selection_wrt_cep_19() { let events = s_pool.stored_events().await; let ephemeral_wraps: Vec<_> = events .iter() - .filter(|e| e.kind == Kind::Custom(contextvm_sdk::core::constants::EPHEMERAL_GIFT_WRAP_KIND)) + .filter(|e| { + e.kind == Kind::Custom(contextvm_sdk::core::constants::EPHEMERAL_GIFT_WRAP_KIND) + }) .cloned() .collect(); - + // 1 from client (request), 1 from server (notification). The client also sends other msgs? - assert!(ephemeral_wraps.len() >= 2, "Expected ephemeral wraps for both request and notification"); + assert!( + ephemeral_wraps.len() >= 2, + "Expected ephemeral wraps for both request and notification" + ); } From 14d5328ba8ae6f15f0694fb218afe8a24c9a6f0c Mon Sep 17 00:00:00 2001 From: Harsh Date: Sat, 2 May 2026 01:25:42 +0530 Subject: [PATCH 56/69] feat: enrich ClientSession, add tag composition, server discovery tag emission and capability learning --- src/core/types.rs | 13 +- src/transport/base.rs | 68 +++++ src/transport/server/mod.rs | 167 +++++++++--- src/transport/server/session_store.rs | 120 +++++++++ tests/transport_integration.rs | 350 ++++++++++++++++++++++++-- 5 files changed, 663 insertions(+), 55 deletions(-) diff --git a/src/core/types.rs b/src/core/types.rs index 2df5818..d66311e 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -91,10 +91,16 @@ pub struct ClientSession { pub is_initialized: bool, /// Whether the client's messages were encrypted. pub is_encrypted: bool, - /// Whether common discovery tags have been sent to this client. + /// Whether server discovery tags have been sent to this client (one-shot flag). pub has_sent_common_tags: bool, - /// Whether the client has demonstrated support for ephemeral gift wraps. + /// Whether the client has demonstrated support for ephemeral gift wraps (CEP-19). pub supports_ephemeral_gift_wrap: bool, + /// Learned from client discovery tags: peer supports NIP-44 encryption. + pub supports_encryption: bool, + /// Learned from client discovery tags: peer supports ephemeral gift wraps (CEP-19). + pub supports_ephemeral_encryption: bool, + /// Learned from client discovery tags: peer supports CEP-22 oversized transfer. + pub supports_oversized_transfer: bool, /// Last activity timestamp. pub last_activity: Instant, /// Pending requests: event_id → original request ID. @@ -111,6 +117,9 @@ impl ClientSession { is_encrypted, has_sent_common_tags: false, supports_ephemeral_gift_wrap: false, + supports_encryption: false, + supports_ephemeral_encryption: false, + supports_oversized_transfer: false, last_activity: Instant::now(), pending_requests: HashMap::new(), event_to_progress_token: HashMap::new(), diff --git a/src/transport/base.rs b/src/transport/base.rs index 781d58b..bacb9c0 100644 --- a/src/transport/base.rs +++ b/src/transport/base.rs @@ -232,6 +232,21 @@ impl BaseTransport { pub fn create_response_tags(pubkey: &PublicKey, event_id: &EventId) -> Vec { vec![Tag::public_key(*pubkey), Tag::event(*event_id)] } + + /// Compose outbound event tags in canonical order: + /// routing (p, e) -> discovery (one-shot caps) -> negotiation (pmi, persistent). + pub fn compose_outbound_tags( + base_tags: &[Tag], + discovery_tags: &[Tag], + negotiation_tags: &[Tag], + ) -> Vec { + let mut tags = + Vec::with_capacity(base_tags.len() + discovery_tags.len() + negotiation_tags.len()); + tags.extend_from_slice(base_tags); + tags.extend_from_slice(discovery_tags); + tags.extend_from_slice(negotiation_tags); + tags + } } #[cfg(test)] @@ -396,4 +411,57 @@ mod tests { let big = "x".repeat(MAX_MESSAGE_SIZE + 1); assert!(!crate::core::validation::validate_message_size(&big)); } + + // ── compose_outbound_tags ────────────────────────────────── + + fn make_custom_tag(name: &str) -> Tag { + Tag::custom(TagKind::Custom(name.into()), Vec::::new()) + } + + #[test] + fn compose_outbound_tags_ordering() { + let keys = Keys::generate(); + let base = vec![Tag::public_key(keys.public_key())]; + let discovery = vec![make_custom_tag("support_encryption")]; + let negotiation = vec![make_custom_tag("pmi")]; + + let result = BaseTransport::compose_outbound_tags(&base, &discovery, &negotiation); + assert_eq!(result.len(), 3); + assert_eq!(result[0].clone().to_vec()[0], "p"); + assert_eq!(result[1].clone().to_vec()[0], "support_encryption"); + assert_eq!(result[2].clone().to_vec()[0], "pmi"); + } + + #[test] + fn compose_outbound_tags_empty_discovery() { + let keys = Keys::generate(); + let base = vec![Tag::public_key(keys.public_key())]; + let negotiation = vec![make_custom_tag("pmi")]; + + let result = BaseTransport::compose_outbound_tags(&base, &[], &negotiation); + assert_eq!(result.len(), 2); + assert_eq!(result[0].clone().to_vec()[0], "p"); + assert_eq!(result[1].clone().to_vec()[0], "pmi"); + } + + #[test] + fn compose_outbound_tags_all_empty() { + let result = BaseTransport::compose_outbound_tags(&[], &[], &[]); + assert!(result.is_empty()); + } + + #[test] + fn compose_outbound_tags_preserves_all_elements() { + let discovery = vec![ + make_custom_tag("support_encryption"), + make_custom_tag("support_encryption_ephemeral"), + ]; + let result = BaseTransport::compose_outbound_tags(&[], &discovery, &[]); + assert_eq!(result.len(), 2); + assert_eq!(result[0].clone().to_vec()[0], "support_encryption"); + assert_eq!( + result[1].clone().to_vec()[0], + "support_encryption_ephemeral" + ); + } } diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 5a4e68f..022aee1 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -26,6 +26,7 @@ use crate::core::validation; use crate::encryption; use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; +use crate::transport::discovery_tags::learn_peer_capabilities; use crate::util::tracing_setup; @@ -333,8 +334,8 @@ impl NostrServerTransport { let original_request_id = route.original_request_id; let progress_token = route.progress_token; - let sessions = self.sessions.read().await; - let session = sessions.get(&client_pubkey_hex).ok_or_else(|| { + let mut sessions_w = self.sessions.write().await; + let session = sessions_w.get_mut(&client_pubkey_hex).ok_or_else(|| { tracing::error!( target: LOG_TARGET, client_pubkey = %client_pubkey_hex, @@ -351,7 +352,10 @@ impl NostrServerTransport { } let is_encrypted = session.is_encrypted; - drop(sessions); + + // CEP-35: include discovery tags on first response to this client + let discovery_tags = self.take_pending_server_discovery_tags(session); + drop(sessions_w); // CEP-19: Look up the incoming wrap kind for mirroring let mirrored_wrap_kind = self @@ -382,23 +386,8 @@ impl NostrServerTransport { Error::Other(error.to_string()) })?; - let mut tags = BaseTransport::create_response_tags(&client_pubkey, &event_id_parsed); - - // Send server info and capabilities on the first response. - let mut sent_common_tags = false; - let session_snapshot = self.sessions.get_session(&client_pubkey_hex).await; - if let Some(snap) = session_snapshot { - if !snap.has_sent_common_tags { - Self::append_common_response_tags( - &mut tags, - self.config.server_info.as_ref(), - &self.extra_common_tags, - self.config.encryption_mode, - self.config.gift_wrap_mode, - ); - sent_common_tags = true; - } - } + let base_tags = BaseTransport::create_response_tags(&client_pubkey, &event_id_parsed); + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); if let Err(error) = self .base @@ -437,16 +426,7 @@ impl NostrServerTransport { return Err(error); } - if sent_common_tags { - self.sessions - .mark_common_tags_sent(&client_pubkey_hex) - .await; - } - - // Clean up only after successful send - self.event_routes.pop(event_id).await; - - // Clean up wrap-kind tracking and reverse mapping + // Clean up wrap-kind tracking self.request_wrap_kinds.write().await.remove(event_id); let mut sessions = self.sessions.write().await; @@ -477,23 +457,28 @@ impl NostrServerTransport { notification: &JsonRpcMessage, correlated_event_id: Option<&str>, ) -> Result<()> { - let sessions = self.sessions.read().await; + let mut sessions = self.sessions.write().await; let session = sessions - .get(client_pubkey_hex) + .get_mut(client_pubkey_hex) .ok_or_else(|| Error::Other(format!("No session for {client_pubkey_hex}")))?; let is_encrypted = session.is_encrypted; let supports_ephemeral = session.supports_ephemeral_gift_wrap; + + // CEP-35: include discovery tags on first message to this client + let discovery_tags = self.take_pending_server_discovery_tags(session); drop(sessions); let client_pubkey = PublicKey::from_hex(client_pubkey_hex).map_err(|e| Error::Other(e.to_string()))?; - let mut tags = BaseTransport::create_recipient_tags(&client_pubkey); + let mut base_tags = BaseTransport::create_recipient_tags(&client_pubkey); if let Some(eid) = correlated_event_id { let event_id = EventId::from_hex(eid).map_err(|e| Error::Other(e.to_string()))?; - tags.push(Tag::event(event_id)); + base_tags.push(Tag::event(event_id)); } + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); + // CEP-19: Look up mirrored wrap kind from correlated request let correlated_wrap_kind = if let Some(event_id) = correlated_event_id { self.request_wrap_kinds @@ -732,6 +717,70 @@ impl NostrServerTransport { self.publish_resource_templates(templates).await } + // ── CEP-35 discovery tag helpers ────────────────────────────── + + /// Build common discovery tags from server config. + /// + /// Includes server info tags (name, about, website, picture) and capability + /// tags (support_encryption, support_encryption_ephemeral) based on the + /// transport's encryption and gift-wrap mode. + fn get_common_tags(&self) -> Vec { + let mut tags = Vec::new(); + + // Server info tags + if let Some(ref info) = self.config.server_info { + if let Some(ref name) = info.name { + tags.push(Tag::custom( + TagKind::Custom(tags::NAME.into()), + vec![name.clone()], + )); + } + if let Some(ref about) = info.about { + tags.push(Tag::custom( + TagKind::Custom(tags::ABOUT.into()), + vec![about.clone()], + )); + } + if let Some(ref website) = info.website { + tags.push(Tag::custom( + TagKind::Custom(tags::WEBSITE.into()), + vec![website.clone()], + )); + } + if let Some(ref picture) = info.picture { + tags.push(Tag::custom( + TagKind::Custom(tags::PICTURE.into()), + vec![picture.clone()], + )); + } + } + + // Capability tags + if self.config.encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if self.config.gift_wrap_mode.supports_ephemeral() { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + + tags + } + + /// One-shot: returns common tags if not yet sent to this client, empty otherwise. + fn take_pending_server_discovery_tags(&self, session: &mut ClientSession) -> Vec { + if session.has_sent_common_tags { + return vec![]; + } + session.has_sent_common_tags = true; + self.get_common_tags() + } + // ── Internal ──────────────────────────────────────────────── fn is_capability_excluded( @@ -792,7 +841,7 @@ impl NostrServerTransport { continue; } - let (content, sender_pubkey, event_id, is_encrypted) = if is_gift_wrap { + let (content, sender_pubkey, event_id, is_encrypted, inner_tags) = if is_gift_wrap { if encryption_mode == EncryptionMode::Disabled { tracing::warn!( target: LOG_TARGET, @@ -848,11 +897,13 @@ impl NostrServerTransport { }; guard.put(event.id, ()); } + let inner_tags: Vec = inner.tags.to_vec(); ( inner.content, inner.pubkey.to_hex(), inner.id.to_hex(), true, + inner_tags, ) } Err(error) => { @@ -888,6 +939,7 @@ impl NostrServerTransport { event.pubkey.to_hex(), event.id.to_hex(), false, + event.tags.to_vec(), ) }; @@ -1014,6 +1066,16 @@ impl NostrServerTransport { session.supports_ephemeral_gift_wrap = true; } + // CEP-35: learn client capabilities from inner event tags + let discovered = learn_peer_capabilities(&inner_tags); + session.supports_encryption |= discovered.supports_encryption; + session.supports_ephemeral_encryption |= discovered.supports_ephemeral_encryption; + // Only learn oversized support if CEP-22 is enabled on this server + // TODO: wire from config when CEP-22 lands + let oversized_enabled = false; + session.supports_oversized_transfer |= + oversized_enabled && discovered.supports_oversized_transfer; + // Track request for correlation if let JsonRpcMessage::Request(ref req) = mcp_msg { let original_id = req.id.clone(); @@ -1449,7 +1511,6 @@ mod tests { #[test] fn test_select_outbound_gift_wrap_kind_plaintext() { - // Plaintext: no encryption regardless of mode assert_eq!( NostrServerTransport::select_outbound_gift_wrap_kind( GiftWrapMode::Optional, @@ -1462,7 +1523,6 @@ mod tests { #[test] fn test_select_outbound_gift_wrap_kind_mirrors_incoming() { - // Mirrors ephemeral kind when Optional mode allows it assert_eq!( NostrServerTransport::select_outbound_gift_wrap_kind( GiftWrapMode::Optional, @@ -1475,7 +1535,6 @@ mod tests { #[test] fn test_select_outbound_gift_wrap_kind_persistent_mode_overrides_ephemeral() { - // Persistent mode: ephemeral mirror ignored, falls back to GIFT_WRAP_KIND assert_eq!( NostrServerTransport::select_outbound_gift_wrap_kind( GiftWrapMode::Persistent, @@ -1660,4 +1719,36 @@ mod tests { .and_then(|t| t.clone().to_vec().get(1).cloned()); assert_eq!(tag_value.as_deref(), Some("value")); } + + // ── CEP-35 discovery tag helpers ──────────────────────────── + + #[test] + fn test_cep35_client_session_new_fields_default_false() { + let session = ClientSession::new(false); + assert!(!session.has_sent_common_tags); + assert!(!session.supports_encryption); + assert!(!session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[test] + fn test_cep35_capability_or_assign() { + let mut session = ClientSession::new(false); + + session.supports_encryption |= true; + session.supports_ephemeral_encryption |= false; + + session.supports_encryption |= false; + session.supports_ephemeral_encryption |= true; + + assert!(session.supports_encryption, "OR-assign must not downgrade"); + assert!(session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[test] + fn test_config_gift_wrap_mode_default() { + let config = NostrServerTransportConfig::default(); + assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); + } } diff --git a/src/transport/server/session_store.rs b/src/transport/server/session_store.rs index 7f29e25..6188482 100644 --- a/src/transport/server/session_store.rs +++ b/src/transport/server/session_store.rs @@ -218,4 +218,124 @@ mod tests { assert!(keys.contains(&"client-1")); assert!(keys.contains(&"client-2")); } + + // ── CEP-35 capability fields ──────────────────────────────── + + #[tokio::test] + async fn new_session_capability_fields_default_false() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + let sessions = store.read().await; + let session = sessions.get("client-1").unwrap(); + assert!(!session.has_sent_common_tags); + assert!(!session.supports_encryption); + assert!(!session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[tokio::test] + async fn has_sent_common_tags_flag() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + assert!(!session.has_sent_common_tags); + session.has_sent_common_tags = true; + assert!(session.has_sent_common_tags); + } + + #[tokio::test] + async fn capability_or_assign_persists() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + // First update: learn encryption support + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption |= true; + session.supports_ephemeral_encryption |= false; + } + + // Second update: learn ephemeral support; encryption stays true + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption |= false; // should stay true + session.supports_ephemeral_encryption |= true; + } + + let sessions = store.read().await; + let session = sessions.get("client-1").unwrap(); + assert!(session.supports_encryption, "OR-assign must not downgrade"); + assert!(session.supports_ephemeral_encryption); + assert!(!session.supports_oversized_transfer); + } + + #[tokio::test] + async fn capability_fields_independent_per_client() { + let store = SessionStore::new(); + store.get_or_create_session("client-a", false).await; + store.get_or_create_session("client-b", false).await; + + { + let mut sessions = store.write().await; + let sa = sessions.get_mut("client-a").unwrap(); + sa.supports_encryption = true; + sa.has_sent_common_tags = true; + } + + let sessions = store.read().await; + let sa = sessions.get("client-a").unwrap(); + let sb = sessions.get("client-b").unwrap(); + assert!(sa.supports_encryption); + assert!(sa.has_sent_common_tags); + assert!(!sb.supports_encryption); + assert!(!sb.has_sent_common_tags); + } + + #[tokio::test] + async fn get_or_create_preserves_capability_fields() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + + // Set capability fields + { + let mut sessions = store.write().await; + let session = sessions.get_mut("client-1").unwrap(); + session.supports_encryption = true; + session.has_sent_common_tags = true; + } + + // Re-enter via get_or_create (existing session) + let created = store.get_or_create_session("client-1", true).await; + assert!(!created); + + // Capability fields must survive + let sessions = store.read().await; + let session = sessions.get("client-1").unwrap(); + assert!(session.supports_encryption); + assert!(session.has_sent_common_tags); + } + + #[tokio::test] + async fn clear_resets_capability_fields() { + let store = SessionStore::new(); + store.get_or_create_session("client-1", false).await; + { + let mut sessions = store.write().await; + let s = sessions.get_mut("client-1").unwrap(); + s.supports_encryption = true; + } + + store.clear().await; + store.get_or_create_session("client-1", false).await; + + let sessions = store.read().await; + let session = sessions.get("client-1").unwrap(); + assert!(!session.supports_encryption); + assert!(!session.has_sent_common_tags); + } } diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index a51ee8b..e84d5b9 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -12,11 +12,13 @@ use std::sync::{ use std::time::Duration; use async_trait::async_trait; +use contextvm_sdk::core::constants::tags; use contextvm_sdk::core::constants::{ - mcp_protocol_version, CTXVM_MESSAGES_KIND, GIFT_WRAP_KIND, PROMPTS_LIST_KIND, - RESOURCES_LIST_KIND, RESOURCETEMPLATES_LIST_KIND, SERVER_ANNOUNCEMENT_KIND, TOOLS_LIST_KIND, + mcp_protocol_version, CTXVM_MESSAGES_KIND, EPHEMERAL_GIFT_WRAP_KIND, GIFT_WRAP_KIND, + PROMPTS_LIST_KIND, RESOURCES_LIST_KIND, RESOURCETEMPLATES_LIST_KIND, SERVER_ANNOUNCEMENT_KIND, + TOOLS_LIST_KIND, }; -use contextvm_sdk::core::types::EncryptionMode; +use contextvm_sdk::core::types::{EncryptionMode, GiftWrapMode}; use contextvm_sdk::relay::mock::MockRelayPool; use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; @@ -2330,7 +2332,7 @@ async fn announced_server_does_not_error_on_unauthorized_notification() { ); } -// ── 31. First response includes discovery tags ────────────────────────────── +// ── 31. First response includes discovery tags (upstream CEP-19) ───────────── #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn first_response_includes_discovery_tags() { @@ -2346,7 +2348,7 @@ async fn first_response_includes_discovery_tags() { ..Default::default() }), encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: contextvm_sdk::core::types::GiftWrapMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, ..Default::default() }, Arc::clone(&s_pool) as Arc, @@ -2429,7 +2431,7 @@ async fn first_response_includes_discovery_tags() { let events = s_pool.stored_events().await; let responses: Vec<_> = events .iter() - .filter(|e| e.kind == Kind::Custom(contextvm_sdk::core::constants::CTXVM_MESSAGES_KIND)) + .filter(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) .cloned() .collect(); @@ -2472,7 +2474,7 @@ async fn notification_mirror_selection_wrt_cep_19() { let mut server = NostrServerTransport::with_relay_pool( NostrServerTransportConfig { encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: contextvm_sdk::core::types::GiftWrapMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, ..Default::default() }, Arc::clone(&s_pool) as Arc, @@ -2484,7 +2486,7 @@ async fn notification_mirror_selection_wrt_cep_19() { NostrClientTransportConfig { server_pubkey: server_pubkey.to_hex(), encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: contextvm_sdk::core::types::GiftWrapMode::Ephemeral, // Forces client to use Ephemeral + gift_wrap_mode: GiftWrapMode::Ephemeral, ..Default::default() }, as_pool(client_pool), @@ -2500,7 +2502,6 @@ async fn notification_mirror_selection_wrt_cep_19() { client.start().await.expect("client start"); let_event_loops_start().await; - // Send a request. It should be encrypted and wrapped with Ephemeral (21059) let request1 = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!("req-1"), @@ -2514,7 +2515,6 @@ async fn notification_mirror_selection_wrt_cep_19() { .expect("timeout") .expect("channel closed"); - // Reply with a correlated notification let notification = JsonRpcMessage::Notification(JsonRpcNotification { jsonrpc: "2.0".to_string(), method: "notifications/progress".to_string(), @@ -2529,19 +2529,339 @@ async fn notification_mirror_selection_wrt_cep_19() { .await .expect("send notification"); - // The notification should have been sent as an Ephemeral Gift Wrap (21059) let events = s_pool.stored_events().await; let ephemeral_wraps: Vec<_> = events .iter() - .filter(|e| { - e.kind == Kind::Custom(contextvm_sdk::core::constants::EPHEMERAL_GIFT_WRAP_KIND) - }) + .filter(|e| e.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND)) .cloned() .collect(); - // 1 from client (request), 1 from server (notification). The client also sends other msgs? assert!( ephemeral_wraps.len() >= 2, "Expected ephemeral wraps for both request and notification" ); } + +// ── CEP-35: Server-side discovery tag emission & capability learning ───────── + +fn event_tag_vecs(event: &Event) -> Vec> { + event.tags.iter().map(|t| t.clone().to_vec()).collect() +} + +fn has_tag_name(event: &Event, name: &str) -> bool { + event_tag_vecs(event) + .iter() + .any(|v| v.first().map(|s| s.as_str()) == Some(name)) +} + +fn get_tag_value(event: &Event, name: &str) -> Option { + event_tag_vecs(event).iter().find_map(|v| { + if v.first().map(|s| s.as_str()) == Some(name) { + v.get(1).cloned() + } else { + None + } + }) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_response_includes_encryption_tags_when_enabled() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .expect("response event must exist"); + + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "first response must include support_encryption when mode != Disabled" + ); + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "first response must include support_encryption_ephemeral when GiftWrapMode != Persistent" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_response_excludes_ephemeral_tag_when_persistent() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Persistent, + ..Default::default() + }, + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .unwrap(); + + assert!( + has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "support_encryption must be present" + ); + assert!( + !has_tag_name(response_event, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "support_encryption_ephemeral must NOT be present when GiftWrapMode is Persistent" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_learns_capabilities_from_client_request() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + assert_eq!(incoming.message.method(), Some("initialize")); + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming2 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + assert_eq!(incoming2.message.method(), Some("tools/list")); + assert_eq!(incoming.client_pubkey, incoming2.client_pubkey); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_disabled_encryption_omits_encryption_tags() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_pool_arc = Arc::new(server_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + server_info: Some(ServerInfo { + name: Some("NoEncrypt".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&server_pool_arc) as Arc, + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let events = server_pool_arc.stored_events().await; + let response_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND) && has_tag_name(e, "e")) + .unwrap(); + + assert!(has_tag_name(response_event, tags::NAME)); + assert!( + !has_tag_name(response_event, tags::SUPPORT_ENCRYPTION), + "encryption tags must be omitted when EncryptionMode is Disabled" + ); + assert!(!has_tag_name( + response_event, + tags::SUPPORT_ENCRYPTION_EPHEMERAL + )); +} From 378ae4d7e1552fcb4eed0c43a264b846a5fd856f Mon Sep 17 00:00:00 2001 From: Harsh Date: Sun, 3 May 2026 00:52:02 +0530 Subject: [PATCH 57/69] feat: client discovery tag emission and capability learning for CEP-35 --- src/transport/client/mod.rs | 407 +++++++++++++++++++++++++++------ tests/transport_integration.rs | 196 ++++++++++++++++ 2 files changed, 534 insertions(+), 69 deletions(-) diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 675be02..65aa47b 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -23,6 +23,7 @@ use crate::core::validation; use crate::encryption; use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; +use crate::transport::discovery_tags::{parse_discovered_peer_capabilities, PeerCapabilities}; use crate::util::tracing_setup; @@ -67,6 +68,12 @@ pub struct NostrClientTransport { server_pubkey: PublicKey, /// Pending request event IDs awaiting responses. pending_requests: ClientCorrelationStore, + /// CEP-35: one-shot flag for client discovery tag emission. + has_sent_discovery_tags: AtomicBool, + /// CEP-35: learned server capabilities from inbound discovery tags. + discovered_server_capabilities: Arc>, + /// CEP-35: first inbound event carrying discovery tags (session baseline). + server_initialize_event: Arc>>, /// Learned support for server-side ephemeral gift wraps. server_supports_ephemeral: Arc, /// Outer gift-wrap event IDs successfully decrypted and verified (inner `verify()`). @@ -126,6 +133,9 @@ impl NostrClientTransport { config, server_pubkey, pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, message_tx: tx, @@ -171,6 +181,9 @@ impl NostrClientTransport { config, server_pubkey, pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, message_tx: tx, @@ -226,6 +239,8 @@ impl NostrClientTransport { let tx = self.message_tx.clone(); let encryption_mode = self.config.encryption_mode; let gift_wrap_mode = self.config.gift_wrap_mode; + let discovered_caps = self.discovered_server_capabilities.clone(); + let init_event = self.server_initialize_event.clone(); let server_supports_ephemeral = self.server_supports_ephemeral.clone(); let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); @@ -237,6 +252,8 @@ impl NostrClientTransport { tx, encryption_mode, gift_wrap_mode, + discovered_caps, + init_event, server_supports_ephemeral, seen_gift_wrap_ids, ) @@ -273,7 +290,14 @@ impl NostrClientTransport { } } - let tags = BaseTransport::create_recipient_tags(&self.server_pubkey); + let is_request = message.is_request(); + let base_tags = BaseTransport::create_recipient_tags(&self.server_pubkey); + let discovery_tags = if is_request { + self.get_pending_client_discovery_tags() + } else { + vec![] + }; + let tags = BaseTransport::compose_outbound_tags(&base_tags, &discovery_tags, &[]); let (event_id, publishable_event) = self .base @@ -316,6 +340,11 @@ impl NostrClientTransport { return Err(error); } + // Flip one-shot flag only after successful publish + if is_request && !discovery_tags.is_empty() { + self.has_sent_discovery_tags.store(true, Ordering::Relaxed); + } + tracing::debug!( target: LOG_TARGET, event_id = %event_id.to_hex(), @@ -360,6 +389,8 @@ impl NostrClientTransport { tx: tokio::sync::mpsc::UnboundedSender, encryption_mode: EncryptionMode, gift_wrap_mode: GiftWrapMode, + discovered_caps: Arc>, + init_event: Arc>>, server_supports_ephemeral: Arc, seen_gift_wrap_ids: Arc>>, ) { @@ -403,81 +434,91 @@ impl NostrClientTransport { } // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag, verified_tags) = if is_gift_wrap { - { - let guard = match seen_gift_wrap_ids.lock() { - Ok(g) => g, - Err(poisoned) => poisoned.into_inner(), - }; - if guard.contains(&event.id) { - tracing::debug!( - target: LOG_TARGET, - event_id = %event.id.to_hex(), - "Skipping duplicate gift-wrap (outer id)" - ); - continue; - } - } - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match relay_pool.signer().await { - Ok(s) => s, - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to get signer" - ); - continue; + let (actual_event_content, actual_pubkey, e_tag, verified_tags, source_event) = + if is_gift_wrap { + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + continue; + } } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => { - if let Err(e) = inner.verify() { - tracing::warn!( - "Inner event signature verification failed: {e}" + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match relay_pool.signer().await { + Ok(s) => s, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); + continue; + } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { + Ok(decrypted_json) => { + match serde_json::from_str::(&decrypted_json) { + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!( + "Inner event signature verification failed: {e}" + ); + continue; + } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } + let e_tag = serializers::get_tag_value(&inner.tags, "e"); + let inner_clone = inner.clone(); + ( + inner.content, + inner.pubkey, + e_tag, + inner.tags, + inner_clone, + ) + } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" ); continue; } - { - let mut guard = match seen_gift_wrap_ids.lock() { - Ok(g) => g, - Err(poisoned) => poisoned.into_inner(), - }; - guard.put(event.id, ()); - } - let e_tag = serializers::get_tag_value(&inner.tags, "e"); - (inner.content, inner.pubkey, e_tag, inner.tags) - } - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to parse inner event" - ); - continue; } } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt gift wrap" + ); + continue; + } } - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to decrypt gift wrap" - ); - continue; - } - } - } else { - let e_tag = serializers::get_tag_value(&event.tags, "e"); - ( - event.content.clone(), - event.pubkey, - e_tag, - event.tags.clone(), - ) - }; + } else { + let e_tag = serializers::get_tag_value(&event.tags, "e"); + let event_clone = (*event).clone(); + ( + event.content.clone(), + event.pubkey, + e_tag, + event.tags.clone(), + event_clone, + ) + }; // Verify it's from our server if actual_pubkey != server_pubkey { @@ -490,6 +531,9 @@ impl NostrClientTransport { continue; } + // CEP-35: learn server capabilities from discovery tags + Self::learn_server_discovery(&discovered_caps, &init_event, &source_event); + // CEP-19: learn ephemeral support from server if Self::should_learn_ephemeral_support( actual_pubkey, @@ -547,6 +591,88 @@ impl NostrClientTransport { } } + // ── CEP-35 discovery tag helpers ────────────────────────────── + + /// Constructs client capability tags based on config. + fn get_client_capability_tags(&self) -> Vec { + let mut tags = Vec::new(); + if self.config.encryption_mode != EncryptionMode::Disabled { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + )); + if self.config.gift_wrap_mode != GiftWrapMode::Persistent { + tags.push(Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + )); + } + } + tags + } + + /// One-shot: returns capability tags if not yet sent, empty otherwise. + fn get_pending_client_discovery_tags(&self) -> Vec { + if self.has_sent_discovery_tags.load(Ordering::Relaxed) { + vec![] + } else { + self.get_client_capability_tags() + } + } + + /// Parses inbound event tags and updates learned server capabilities. + fn learn_server_discovery( + discovered_caps: &Mutex, + init_event: &Mutex>, + event: &Event, + ) { + let tag_vec: Vec = event.tags.clone().to_vec(); + let discovered = parse_discovered_peer_capabilities(&tag_vec); + if discovered.discovery_tags.is_empty() { + return; + } + + { + let mut caps = match discovered_caps.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + caps.supports_encryption |= discovered.capabilities.supports_encryption; + caps.supports_ephemeral_encryption |= + discovered.capabilities.supports_ephemeral_encryption; + caps.supports_oversized_transfer |= discovered.capabilities.supports_oversized_transfer; + } + + let mut stored = match init_event.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + if stored.is_none() { + *stored = Some(event.clone()); + } + // Note: TS SDK has an upgrade path where a later event with an InitializeResult + // replaces a non-initialize baseline. Not implemented here -- edge case only + // relevant if the first server message with discovery tags is a notification. + } + + /// Returns a clone of the first inbound event that carried server discovery tags. + pub fn get_server_initialize_event(&self) -> Option { + let guard = match self.server_initialize_event.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + guard.clone() + } + + /// Returns a snapshot of the learned server capabilities from discovery tags. + pub fn discovered_server_capabilities(&self) -> PeerCapabilities { + let guard = match self.discovered_server_capabilities.lock() { + Ok(g) => g, + Err(p) => p.into_inner(), + }; + *guard + } + fn choose_outbound_gift_wrap_kind(&self) -> u16 { match self.config.gift_wrap_mode { GiftWrapMode::Persistent => GIFT_WRAP_KIND, @@ -820,4 +946,147 @@ mod tests { "Disabled mode must accept plaintext events" ); } + + // ── CEP-35 client discovery tag emission ──────────────────── + + fn make_transport_for_tags( + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + ) -> NostrClientTransport { + let keys = Keys::generate(); + NostrClientTransport { + base: BaseTransport { + relay_pool: Arc::new(crate::relay::mock::MockRelayPool::new()), + encryption_mode, + is_connected: false, + }, + config: NostrClientTransportConfig { + encryption_mode, + gift_wrap_mode, + server_pubkey: Keys::generate().public_key().to_hex(), + ..Default::default() + }, + server_pubkey: keys.public_key(), + pending_requests: ClientCorrelationStore::new(), + has_sent_discovery_tags: AtomicBool::new(false), + discovered_server_capabilities: Arc::new(Mutex::new(PeerCapabilities::default())), + server_initialize_event: Arc::new(Mutex::new(None)), + server_supports_ephemeral: Arc::new(AtomicBool::new(false)), + seen_gift_wrap_ids: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap()))), + message_tx: tokio::sync::mpsc::unbounded_channel().0, + message_rx: None, + } + } + + fn make_tag(parts: &[&str]) -> Tag { + let kind = TagKind::Custom(parts[0].into()); + let values: Vec = parts[1..].iter().map(|s| s.to_string()).collect(); + Tag::custom(kind, values) + } + + fn tag_names(tags: &[Tag]) -> Vec { + tags.iter().map(|t| t.clone().to_vec()[0].clone()).collect() + } + + #[test] + fn client_capability_tags_encryption_optional() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Optional); + let tags = t.get_client_capability_tags(); + let names = tag_names(&tags); + assert_eq!( + names, + vec!["support_encryption", "support_encryption_ephemeral"] + ); + } + + #[test] + fn client_capability_tags_encryption_disabled() { + let t = make_transport_for_tags(EncryptionMode::Disabled, GiftWrapMode::Optional); + let tags = t.get_client_capability_tags(); + assert!(tags.is_empty()); + } + + #[test] + fn client_capability_tags_persistent_gift_wrap() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Persistent); + let tags = t.get_client_capability_tags(); + let names = tag_names(&tags); + assert_eq!(names, vec!["support_encryption"]); + } + + #[test] + fn client_discovery_tags_sent_once() { + let t = make_transport_for_tags(EncryptionMode::Optional, GiftWrapMode::Optional); + let first = t.get_pending_client_discovery_tags(); + assert!(!first.is_empty()); + + t.has_sent_discovery_tags.store(true, Ordering::Relaxed); + let second = t.get_pending_client_discovery_tags(); + assert!(second.is_empty()); + } + + // ── CEP-35 client capability learning ─────────────────────── + + fn make_event_with_tags(tag_parts: &[&[&str]]) -> Event { + let keys = Keys::generate(); + let tags: Vec = tag_parts.iter().map(|p| make_tag(p)).collect(); + let builder = EventBuilder::new(Kind::Custom(CTXVM_MESSAGES_KIND), "{}").tags(tags); + let unsigned = builder.build(keys.public_key()); + unsigned.sign_with_keys(&keys).unwrap() + } + + #[test] + fn client_learn_server_discovery_sets_baseline() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + let event = make_event_with_tags(&[&["support_encryption"], &["name", "TestServer"]]); + + NostrClientTransport::learn_server_discovery(&caps, &init, &event); + + let c = caps.lock().unwrap(); + assert!(c.supports_encryption); + assert!(!c.supports_ephemeral_encryption); + + let stored = init.lock().unwrap(); + assert!(stored.is_some()); + assert_eq!(stored.as_ref().unwrap().id, event.id); + } + + #[test] + fn client_learn_server_discovery_or_assigns() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + + let event1 = make_event_with_tags(&[&["support_encryption"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event1); + + // Second event with different caps does NOT downgrade + let event2 = make_event_with_tags(&[&["support_encryption_ephemeral"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event2); + + let c = caps.lock().unwrap(); + assert!(c.supports_encryption, "must not downgrade"); + assert!(c.supports_ephemeral_encryption, "must learn new cap"); + } + + #[test] + fn client_baseline_not_replaced_on_later_events() { + let caps = Mutex::new(PeerCapabilities::default()); + let init = Mutex::new(None); + + let event1 = make_event_with_tags(&[&["support_encryption"], &["name", "First"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event1); + let first_id = event1.id; + + let event2 = + make_event_with_tags(&[&["support_encryption_ephemeral"], &["name", "Second"]]); + NostrClientTransport::learn_server_discovery(&caps, &init, &event2); + + let stored = init.lock().unwrap(); + assert_eq!( + stored.as_ref().unwrap().id, + first_id, + "baseline must not be replaced" + ); + } } diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index e84d5b9..0d78b97 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -2865,3 +2865,199 @@ async fn server_disabled_encryption_omits_encryption_tags() { tags::SUPPORT_ENCRYPTION_EPHEMERAL )); } + +// ── CEP-35: Client-side discovery tag emission & capability learning ───────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_disabled_encryption_emits_no_discovery_tags() { + // Disabled encryption: client must not emit cap tags. Positive case (Optional + // mode emits tags) is covered by unit test client_capability_tags_encryption_optional. + let pool = Arc::new(MockRelayPool::new()); + let server_keys = Keys::generate(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_keys.public_key().to_hex(), + encryption_mode: EncryptionMode::Disabled, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + // With Disabled encryption, no cap tags are emitted (correct per spec). + // Verify the event is published with p tag but without cap tags. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = pool.stored_events().await; + let client_event = events + .iter() + .find(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) + .expect("client must publish a request event"); + + // p tag must be present (routing) + assert!(has_tag_name(client_event, "p")); + // No encryption tags when Disabled (the unit test covers the Optional case) + assert!( + !has_tag_name(client_event, tags::SUPPORT_ENCRYPTION), + "Disabled client must not emit support_encryption" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_second_request_carries_no_discovery_tags() { + // Second request must never carry discovery tags. One-shot flag behavior + // is covered by unit test client_discovery_tags_sent_once. + let pool = Arc::new(MockRelayPool::new()); + let server_keys = Keys::generate(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_keys.public_key().to_hex(), + encryption_mode: EncryptionMode::Disabled, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + Arc::clone(&pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + // First request + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + // Second request + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = pool.stored_events().await; + let ctxvm_events: Vec<&Event> = events + .iter() + .filter(|e| e.kind == Kind::Custom(CTXVM_MESSAGES_KIND)) + .collect(); + assert!(ctxvm_events.len() >= 2); + + let second_event = ctxvm_events + .iter() + .find(|e| e.content.contains("tools/list")) + .expect("second request event must exist"); + + assert!( + !has_tag_name(second_event, tags::SUPPORT_ENCRYPTION), + "second request must NOT include discovery tags" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_learns_server_capabilities_from_first_response() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + server_info: Some(ServerInfo { + name: Some("CapServer".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + // Client should have learned capabilities from server's first response + let caps = client.discovered_server_capabilities(); + assert!( + caps.supports_encryption, + "client must learn support_encryption from server response tags" + ); + assert!( + caps.supports_ephemeral_encryption, + "client must learn support_encryption_ephemeral from server response tags" + ); + + let baseline = client.get_server_initialize_event(); + assert!(baseline.is_some(), "baseline event must be set"); +} From 9fb41792b25c6e245b6f82d750af28eb3b35abef Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 5 May 2026 02:06:20 +0530 Subject: [PATCH 58/69] test: add integration tests for CEP-35 OR-assign, baseline freeze, and discovery tag emission --- src/relay/mock.rs | 5 + tests/transport_integration.rs | 307 +++++++++++++++++++++++++++++++-- 2 files changed, 302 insertions(+), 10 deletions(-) diff --git a/src/relay/mock.rs b/src/relay/mock.rs index 52e52bc..a950235 100644 --- a/src/relay/mock.rs +++ b/src/relay/mock.rs @@ -70,6 +70,11 @@ impl MockRelayPool { self.keys.public_key() } + /// The ephemeral signing keys (for manual event injection in tests). + pub fn mock_keys(&self) -> Keys { + self.keys.clone() + } + /// Like [`new`](Self::new) but with caller-provided signing keys. pub fn with_keys(keys: Keys) -> Self { let (tx, _rx) = tokio::sync::broadcast::channel(1024); diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index 0d78b97..e1c4914 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -2554,16 +2554,6 @@ fn has_tag_name(event: &Event, name: &str) -> bool { .any(|v| v.first().map(|s| s.as_str()) == Some(name)) } -fn get_tag_value(event: &Event, name: &str) -> Option { - event_tag_vecs(event).iter().find_map(|v| { - if v.first().map(|s| s.as_str()) == Some(name) { - v.get(1).cloned() - } else { - None - } - }) -} - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn server_response_includes_encryption_tags_when_enabled() { let (client_pool, server_pool) = MockRelayPool::create_pair(); @@ -3061,3 +3051,300 @@ async fn client_learns_server_capabilities_from_first_response() { let baseline = client.get_server_initialize_event(); assert!(baseline.is_some(), "baseline event must be set"); } + +// ── CEP-35: OR-assign, baseline-freeze, and Optional emission ──────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_or_assigns_capabilities_across_responses() { + // Server with Persistent gift-wrap emits support_encryption but NOT + // support_encryption_ephemeral on the first response. A second event + // carrying support_encryption_ephemeral must OR-assign into the client's + // learned caps without downgrading the already-learned support_encryption. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + server_info: Some(ServerInfo { + name: Some("PersistentServer".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Persistent, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + // First roundtrip — server responds with support_encryption only. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let caps_after_first = client.discovered_server_capabilities(); + assert!( + caps_after_first.supports_encryption, + "first response must teach support_encryption" + ); + assert!( + !caps_after_first.supports_ephemeral_encryption, + "Persistent server must NOT advertise ephemeral on first response" + ); + + // Inject a second plaintext event signed by the server, carrying + // support_encryption_ephemeral (simulates a capability upgrade). + let client_pubkey = client_pool.mock_public_key(); + let second_response = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress" + }); + let inject_event = EventBuilder::new( + Kind::Custom(CTXVM_MESSAGES_KIND), + second_response.to_string(), + ) + .tags(vec![ + Tag::public_key(client_pubkey), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION_EPHEMERAL.into()), + Vec::::new(), + ), + ]) + .sign_with_keys(&server_keys) + .unwrap(); + + client_pool.publish_event(&inject_event).await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + + let caps_after_second = client.discovered_server_capabilities(); + assert!( + caps_after_second.supports_encryption, + "support_encryption must survive OR-assign (not downgraded)" + ); + assert!( + caps_after_second.supports_ephemeral_encryption, + "support_encryption_ephemeral must be OR-assigned from second event" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_baseline_event_not_replaced_by_later_responses() { + // The first inbound event carrying discovery tags becomes the baseline. + // Later events with different tags must NOT replace it. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + server_info: Some(ServerInfo { + name: Some("BaselineServer".to_string()), + ..Default::default() + }), + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .unwrap(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + let mut server_rx = server.take_message_receiver().unwrap(); + let mut client_rx = client.take_message_receiver().unwrap(); + server.start().await.unwrap(); + client.start().await.unwrap(); + let_event_loops_start().await; + + // First roundtrip — establishes baseline. + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let incoming = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .unwrap() + .unwrap(); + + server + .send_response( + &incoming.event_id, + JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + result: serde_json::json!({}), + }), + ) + .await + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_millis(500), client_rx.recv()) + .await + .unwrap(); + + let baseline = client.get_server_initialize_event(); + assert!( + baseline.is_some(), + "baseline must be set after first response" + ); + let baseline_id = baseline.unwrap().id; + + // Inject a second event with different discovery tags. + let client_pubkey = client_pool.mock_public_key(); + let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/progress" + }); + let inject_event = + EventBuilder::new(Kind::Custom(CTXVM_MESSAGES_KIND), notification.to_string()) + .tags(vec![ + Tag::public_key(client_pubkey), + Tag::custom( + TagKind::Custom(tags::SUPPORT_ENCRYPTION.into()), + Vec::::new(), + ), + ]) + .sign_with_keys(&server_keys) + .unwrap(); + + client_pool.publish_event(&inject_event).await.unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; + + let baseline_after = client.get_server_initialize_event(); + assert_eq!( + baseline_after.unwrap().id, + baseline_id, + "baseline event must NOT be replaced by later events" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_optional_encryption_emits_discovery_tags() { + // Client with Optional encryption must include discovery tags in the + // inner signed event. We decrypt the published gift wrap to verify. + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + let server_keys = server_pool.mock_keys(); + + let client_pool = Arc::new(client_pool); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + Arc::clone(&client_pool) as Arc, + ) + .await + .unwrap(); + + client.start().await.unwrap(); + let_event_loops_start().await; + + client + .send(&JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: None, + })) + .await + .unwrap(); + + let events = client_pool.stored_events().await; + let gift_wrap = events + .iter() + .find(|e| { + e.kind == Kind::Custom(GIFT_WRAP_KIND) + || e.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND) + }) + .expect("Optional encryption must produce a gift-wrapped event"); + + // Decrypt using the server's keys (the recipient). + let signer: Arc = Arc::new(server_keys); + let decrypted_json = + contextvm_sdk::encryption::decrypt_gift_wrap_single_layer(&signer, gift_wrap) + .await + .expect("gift wrap must be decryptable with server keys"); + + let inner: Event = + serde_json::from_str(&decrypted_json).expect("decrypted content must be a valid Event"); + + assert!( + has_tag_name(&inner, tags::SUPPORT_ENCRYPTION), + "inner event must carry support_encryption tag" + ); + assert!( + has_tag_name(&inner, tags::SUPPORT_ENCRYPTION_EPHEMERAL), + "inner event must carry support_encryption_ephemeral tag (Optional gift-wrap mode)" + ); +} From ae9c85189c86884984e4fe4252ed5af32a663dd4 Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 5 May 2026 16:27:05 +0530 Subject: [PATCH 59/69] fix: add TTL sweep to client and server correlation stores to prevent pending-request leak --- src/gateway/mod.rs | 1 + src/transport/client/correlation_store.rs | 59 +++ src/transport/client/mod.rs | 428 ++++++++++++---------- src/transport/server/correlation_store.rs | 63 ++++ src/transport/server/mod.rs | 23 ++ 5 files changed, 389 insertions(+), 185 deletions(-) diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e9d00f2..f486b1c 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -128,6 +128,7 @@ mod tests { excluded_capabilities: vec![], cleanup_interval: Duration::from_secs(120), session_timeout: Duration::from_secs(600), + request_timeout: Duration::from_secs(60), log_file_path: None, }; diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs index 858b37a..537c1ae 100644 --- a/src/transport/client/correlation_store.rs +++ b/src/transport/client/correlation_store.rs @@ -2,6 +2,7 @@ use std::num::NonZeroUsize; use std::sync::Arc; +use std::time::{Duration, Instant}; use lru::LruCache; use tokio::sync::RwLock; @@ -15,6 +16,8 @@ pub struct PendingRequest { pub original_id: serde_json::Value, /// Whether this request is an `initialize` handshake. pub is_initialize: bool, + /// When the request was registered. + pub registered_at: Instant, } /// Tracks pending request event IDs and their original request IDs on the client side. @@ -59,6 +62,7 @@ impl ClientCorrelationStore { PendingRequest { original_id, is_initialize, + registered_at: Instant::now(), }, ); } @@ -95,6 +99,25 @@ impl ClientCorrelationStore { self.pending_requests.read().await.len() } + /// Remove all entries older than `timeout`. Returns the number of entries removed. + pub async fn sweep_expired(&self, timeout: Duration) -> usize { + let now = Instant::now(); + let mut cache = self.pending_requests.write().await; + let mut expired_keys = Vec::new(); + + for (key, entry) in cache.iter() { + if now.duration_since(entry.registered_at) >= timeout { + expired_keys.push(key.clone()); + } + } + + let count = expired_keys.len(); + for key in expired_keys { + cache.pop(&key); + } + count + } + pub async fn clear(&self) { self.pending_requests.write().await.clear(); } @@ -150,4 +173,40 @@ mod tests { assert!(!store.contains("e0").await); assert!(store.contains(&format!("e{DEFAULT_LRU_SIZE}")).await); } + + #[tokio::test] + async fn sweep_expired_removes_only_stale_entries() { + let store = ClientCorrelationStore::new(); + + // Insert an entry that will be "old" by the time we sweep. + store + .register("old".into(), serde_json::json!(1), false) + .await; + + // Sleep so "old" entry ages past the threshold. + tokio::time::sleep(Duration::from_millis(20)).await; + + // Insert a fresh entry. + store + .register("fresh".into(), serde_json::json!(2), false) + .await; + + // Sweep with a 10ms timeout — "old" should be removed, "fresh" should remain. + let swept = store.sweep_expired(Duration::from_millis(10)).await; + assert_eq!(swept, 1); + assert!(!store.contains("old").await); + assert!(store.contains("fresh").await); + } + + #[tokio::test] + async fn sweep_expired_returns_zero_when_nothing_expired() { + let store = ClientCorrelationStore::new(); + store + .register("e1".into(), serde_json::Value::Null, false) + .await; + + let swept = store.sweep_expired(Duration::from_secs(60)).await; + assert_eq!(swept, 0); + assert!(store.contains("e1").await); + } } diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 65aa47b..0aaafc6 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -243,6 +243,7 @@ impl NostrClientTransport { let init_event = self.server_initialize_event.clone(); let server_supports_ephemeral = self.server_supports_ephemeral.clone(); let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); + let timeout = self.config.timeout; tokio::spawn(async move { Self::event_loop( @@ -256,6 +257,7 @@ impl NostrClientTransport { init_event, server_supports_ephemeral, seen_gift_wrap_ids, + timeout, ) .await; }); @@ -393,200 +395,47 @@ impl NostrClientTransport { init_event: Arc>>, server_supports_ephemeral: Arc, seen_gift_wrap_ids: Arc>>, + timeout: Duration, ) { let mut notifications = relay_pool.notifications(); - - while let Ok(notification) = notifications.recv().await { - if let RelayPoolNotification::Event { event, .. } = notification { - let is_gift_wrap = is_gift_wrap_kind(&event.kind); - let outer_kind = event.kind.as_u16(); - - // Enforce encryption mode before decrypt/parse. - if violates_encryption_policy(&event.kind, &encryption_mode) { - if is_gift_wrap { - tracing::warn!( - target: LOG_TARGET, - event_id = %event.id.to_hex(), - event_kind = outer_kind, - configured_mode = ?gift_wrap_mode, - "Skipping encrypted response because client encryption is disabled" - ); - } else { - tracing::warn!( - target: LOG_TARGET, - event_id = %event.id.to_hex(), - "Skipping plaintext response because client encryption is required" - ); - } - continue; - } - - // Enforce CEP-19 gift-wrap-mode policy. - if is_gift_wrap && !gift_wrap_mode.allows_kind(outer_kind) { - tracing::warn!( - target: LOG_TARGET, - event_id = %event.id.to_hex(), - event_kind = outer_kind, - configured_mode = ?gift_wrap_mode, - "Skipping gift wrap due to CEP-19 policy" - ); - continue; - } - - // Handle gift-wrapped events - let (actual_event_content, actual_pubkey, e_tag, verified_tags, source_event) = - if is_gift_wrap { - { - let guard = match seen_gift_wrap_ids.lock() { - Ok(g) => g, - Err(poisoned) => poisoned.into_inner(), - }; - if guard.contains(&event.id) { - tracing::debug!( - target: LOG_TARGET, - event_id = %event.id.to_hex(), - "Skipping duplicate gift-wrap (outer id)" - ); - continue; - } - } - // Single-layer NIP-44 decrypt (matches JS/TS SDK) - let signer = match relay_pool.signer().await { - Ok(s) => s, - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to get signer" - ); - continue; - } - }; - match encryption::decrypt_gift_wrap_single_layer(&signer, &event).await { - Ok(decrypted_json) => { - match serde_json::from_str::(&decrypted_json) { - Ok(inner) => { - if let Err(e) = inner.verify() { - tracing::warn!( - "Inner event signature verification failed: {e}" - ); - continue; - } - { - let mut guard = match seen_gift_wrap_ids.lock() { - Ok(g) => g, - Err(poisoned) => poisoned.into_inner(), - }; - guard.put(event.id, ()); - } - let e_tag = serializers::get_tag_value(&inner.tags, "e"); - let inner_clone = inner.clone(); - ( - inner.content, - inner.pubkey, - e_tag, - inner.tags, - inner_clone, - ) - } - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to parse inner event" - ); - continue; - } - } - } - Err(error) => { - tracing::error!( - target: LOG_TARGET, - error = %error, - "Failed to decrypt gift wrap" - ); - continue; - } - } - } else { - let e_tag = serializers::get_tag_value(&event.tags, "e"); - let event_clone = (*event).clone(); - ( - event.content.clone(), - event.pubkey, - e_tag, - event.tags.clone(), - event_clone, - ) + // Sweep interval: half the timeout, clamped to [1s, 30s]. + let sweep_interval = (timeout / 2).clamp(Duration::from_secs(1), Duration::from_secs(30)); + let mut sweep_timer = + tokio::time::interval_at(tokio::time::Instant::now() + sweep_interval, sweep_interval); + + loop { + tokio::select! { + result = notifications.recv() => { + let notification = match result { + Ok(n) => n, + Err(_) => break, }; - - // Verify it's from our server - if actual_pubkey != server_pubkey { - tracing::debug!( - target: LOG_TARGET, - event_pubkey = %actual_pubkey.to_hex(), - expected_pubkey = %server_pubkey.to_hex(), - "Skipping event from unexpected pubkey" - ); - continue; + Self::handle_notification( + ¬ification, + &pending, + server_pubkey, + &tx, + encryption_mode, + gift_wrap_mode, + &discovered_caps, + &init_event, + &server_supports_ephemeral, + &seen_gift_wrap_ids, + &relay_pool, + ) + .await; } - - // CEP-35: learn server capabilities from discovery tags - Self::learn_server_discovery(&discovered_caps, &init_event, &source_event); - - // CEP-19: learn ephemeral support from server - if Self::should_learn_ephemeral_support( - actual_pubkey, - server_pubkey, - if is_gift_wrap { Some(outer_kind) } else { None }, - &verified_tags, - ) { - server_supports_ephemeral.store(true, Ordering::Relaxed); - } - - // Correlate response - if let Some(ref correlated_id) = e_tag { - let is_pending = pending.contains(correlated_id.as_str()).await; - if !is_pending { + _ = sweep_timer.tick() => { + let swept = pending.sweep_expired(timeout).await; + if swept > 0 { tracing::warn!( target: LOG_TARGET, - correlated_event_id = %correlated_id, - "Response for unknown request" + swept, + timeout_ms = timeout.as_millis() as u64, + "Swept stale pending requests (rmcp handles timeout errors)" ); - continue; } } - - // Parse MCP message - if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { - // Drop uncorrelated responses and server-to-client requests (matches TS SDK). - match &mcp_msg { - JsonRpcMessage::Response(_) | JsonRpcMessage::ErrorResponse(_) - if e_tag.is_none() => - { - tracing::warn!( - target: LOG_TARGET, - "Dropping response/error without correlation `e` tag" - ); - continue; - } - JsonRpcMessage::Request(_) => { - tracing::warn!( - target: LOG_TARGET, - method = ?mcp_msg.method(), - "Dropping server-to-client request (invalid in MCP)" - ); - continue; - } - _ => {} - } - - // Clean up pending request - if let Some(ref correlated_id) = e_tag { - pending.remove(correlated_id.as_str()).await; - } - let _ = tx.send(mcp_msg); - } } } } @@ -673,6 +522,206 @@ impl NostrClientTransport { *guard } + #[allow(clippy::too_many_arguments)] + async fn handle_notification( + notification: &RelayPoolNotification, + pending: &ClientCorrelationStore, + server_pubkey: PublicKey, + tx: &tokio::sync::mpsc::UnboundedSender, + encryption_mode: EncryptionMode, + gift_wrap_mode: GiftWrapMode, + discovered_caps: &Arc>, + init_event: &Arc>>, + server_supports_ephemeral: &Arc, + seen_gift_wrap_ids: &Arc>>, + relay_pool: &Arc, + ) { + let event = match notification { + RelayPoolNotification::Event { event, .. } => event, + _ => return, + }; + + let is_gift_wrap = is_gift_wrap_kind(&event.kind); + let outer_kind = event.kind.as_u16(); + + // Enforce encryption mode before decrypt/parse. + if violates_encryption_policy(&event.kind, &encryption_mode) { + if is_gift_wrap { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping encrypted response because client encryption is disabled" + ); + } else { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping plaintext response because client encryption is required" + ); + } + return; + } + + // Enforce CEP-19 gift-wrap-mode policy. + if is_gift_wrap && !gift_wrap_mode.allows_kind(outer_kind) { + tracing::warn!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + event_kind = outer_kind, + configured_mode = ?gift_wrap_mode, + "Skipping gift wrap due to CEP-19 policy" + ); + return; + } + + // Handle gift-wrapped events + let (actual_event_content, actual_pubkey, e_tag, verified_tags, source_event) = + if is_gift_wrap { + { + let guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + if guard.contains(&event.id) { + tracing::debug!( + target: LOG_TARGET, + event_id = %event.id.to_hex(), + "Skipping duplicate gift-wrap (outer id)" + ); + return; + } + } + // Single-layer NIP-44 decrypt (matches JS/TS SDK) + let signer = match relay_pool.signer().await { + Ok(s) => s, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to get signer" + ); + return; + } + }; + match encryption::decrypt_gift_wrap_single_layer(&signer, event).await { + Ok(decrypted_json) => match serde_json::from_str::(&decrypted_json) { + Ok(inner) => { + if let Err(e) = inner.verify() { + tracing::warn!("Inner event signature verification failed: {e}"); + return; + } + { + let mut guard = match seen_gift_wrap_ids.lock() { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + guard.put(event.id, ()); + } + let e_tag = serializers::get_tag_value(&inner.tags, "e"); + let inner_clone = inner.clone(); + (inner.content, inner.pubkey, e_tag, inner.tags, inner_clone) + } + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to parse inner event" + ); + return; + } + }, + Err(error) => { + tracing::error!( + target: LOG_TARGET, + error = %error, + "Failed to decrypt gift wrap" + ); + return; + } + } + } else { + let e_tag = serializers::get_tag_value(&event.tags, "e"); + let event_clone: Event = (**event).clone(); + ( + event.content.clone(), + event.pubkey, + e_tag, + event.tags.clone(), + event_clone, + ) + }; + + // Verify it's from our server + if actual_pubkey != server_pubkey { + tracing::debug!( + target: LOG_TARGET, + event_pubkey = %actual_pubkey.to_hex(), + expected_pubkey = %server_pubkey.to_hex(), + "Skipping event from unexpected pubkey" + ); + return; + } + + // CEP-35: learn server capabilities from discovery tags + Self::learn_server_discovery(discovered_caps, init_event, &source_event); + + // CEP-19: learn ephemeral support from server + if Self::should_learn_ephemeral_support( + actual_pubkey, + server_pubkey, + if is_gift_wrap { Some(outer_kind) } else { None }, + &verified_tags, + ) { + server_supports_ephemeral.store(true, Ordering::Relaxed); + } + + // Correlate response + if let Some(ref correlated_id) = e_tag { + let is_pending = pending.contains(correlated_id.as_str()).await; + if !is_pending { + tracing::warn!( + target: LOG_TARGET, + correlated_event_id = %correlated_id, + "Response for unknown request" + ); + return; + } + } + + // Parse MCP message + if let Some(mcp_msg) = validation::validate_and_parse(&actual_event_content) { + // Drop uncorrelated responses and server-to-client requests (matches TS SDK). + match &mcp_msg { + JsonRpcMessage::Response(_) | JsonRpcMessage::ErrorResponse(_) + if e_tag.is_none() => + { + tracing::warn!( + target: LOG_TARGET, + "Dropping response/error without correlation `e` tag" + ); + return; + } + JsonRpcMessage::Request(_) => { + tracing::warn!( + target: LOG_TARGET, + method = ?mcp_msg.method(), + "Dropping server-to-client request (invalid in MCP)" + ); + return; + } + _ => {} + } + + // Clean up pending request + if let Some(ref correlated_id) = e_tag { + pending.remove(correlated_id.as_str()).await; + } + let _ = tx.send(mcp_msg); + } + } + fn choose_outbound_gift_wrap_kind(&self) -> u16 { match self.config.gift_wrap_mode { GiftWrapMode::Persistent => GIFT_WRAP_KIND, @@ -752,6 +801,15 @@ mod tests { assert!(config.is_stateless); } + #[test] + fn test_custom_timeout_config() { + let config = NostrClientTransportConfig { + timeout: Duration::from_secs(60), + ..Default::default() + }; + assert_eq!(config.timeout, Duration::from_secs(60)); + } + #[test] fn test_has_support_ephemeral_tag_detects_capability() { let tags = Tags::from_list(vec![Tag::custom( diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index cad616e..80353a8 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -3,6 +3,7 @@ use std::collections::{HashMap, HashSet}; use std::num::NonZeroUsize; use std::sync::Arc; +use std::time::{Duration, Instant}; use lru::LruCache; use tokio::sync::RwLock; @@ -21,6 +22,8 @@ pub struct RouteEntry { /// The outer gift-wrap event kind that carried this request (e.g. 1059 or 21059). /// Populated from the inbound event in a later PR; `None` until then. pub wrap_kind: Option, + /// When the route was registered. + pub registered_at: Instant, } /// Internal state behind the lock. @@ -127,6 +130,7 @@ impl ServerEventRouteStore { original_request_id, progress_token, wrap_kind: None, + registered_at: Instant::now(), }, ); @@ -222,6 +226,26 @@ impl ServerEventRouteStore { self.inner.read().await.progress_token_to_event.len() } + /// Remove all route entries older than `timeout`. + /// (Routes for expired sessions are already cleaned by `cleanup_sessions`.) + /// Returns the event IDs of the removed entries. + pub async fn sweep_stale_routes(&self, timeout: Duration) -> Vec { + let now = Instant::now(); + let mut inner = self.inner.write().await; + let mut expired_keys = Vec::new(); + + for (key, entry) in inner.routes.iter() { + if now.duration_since(entry.registered_at) >= timeout { + expired_keys.push(key.clone()); + } + } + + for key in &expired_keys { + inner.remove_route(key); + } + expired_keys + } + pub async fn clear(&self) { let mut inner = self.inner.write().await; inner.routes.clear(); @@ -321,4 +345,43 @@ mod tests { assert!(!store.has_event_route("e0").await); assert!(store.has_event_route(&format!("e{DEFAULT_LRU_SIZE}")).await); } + + #[tokio::test] + async fn sweep_stale_routes_removes_only_expired() { + let store = ServerEventRouteStore::new(); + + // Insert a route that will age past the threshold. + store + .register("old".into(), "pk1".into(), json!(1), Some("tok1".into())) + .await; + + tokio::time::sleep(Duration::from_millis(20)).await; + + // Insert a fresh route. + store + .register("fresh".into(), "pk2".into(), json!(2), None) + .await; + + // Sweep with 10ms timeout — "old" should be removed, "fresh" should remain. + let swept = store.sweep_stale_routes(Duration::from_millis(10)).await; + assert_eq!(swept.len(), 1); + assert_eq!(swept[0], "old"); + assert!(!store.has_event_route("old").await); + assert!(store.has_event_route("fresh").await); + // Secondary indexes should also be cleaned. + assert!(!store.has_progress_token("tok1").await); + assert!(!store.has_active_routes_for_client("pk1").await); + } + + #[tokio::test] + async fn sweep_stale_routes_returns_zero_when_nothing_expired() { + let store = ServerEventRouteStore::new(); + store + .register("e1".into(), "pk1".into(), json!(1), None) + .await; + + let swept = store.sweep_stale_routes(Duration::from_secs(60)).await; + assert!(swept.is_empty()); + assert!(store.has_event_route("e1").await); + } } diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 022aee1..9dd3fc3 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -52,6 +52,8 @@ pub struct NostrServerTransportConfig { pub cleanup_interval: Duration, /// Session timeout (default: 300s). pub session_timeout: Duration, + /// Request route timeout (default: 60s). Stale routes older than this are swept. + pub request_timeout: Duration, /// Optional log file path. Logs always go to stdout and are also appended here when set. pub log_file_path: Option, } @@ -68,6 +70,7 @@ impl Default for NostrServerTransportConfig { excluded_capabilities: Vec::new(), cleanup_interval: Duration::from_secs(60), session_timeout: Duration::from_secs(300), + request_timeout: Duration::from_secs(60), log_file_path: None, } } @@ -277,6 +280,7 @@ impl NostrServerTransport { let request_wrap_kinds_cleanup = self.request_wrap_kinds.clone(); let cleanup_interval = self.config.cleanup_interval; let session_timeout = self.config.session_timeout; + let request_timeout = self.config.request_timeout; tokio::spawn(async move { let mut interval = tokio::time::interval(cleanup_interval); @@ -296,6 +300,24 @@ impl NostrServerTransport { "Cleaned up inactive sessions" ); } + + // Sweep stale route entries in active sessions (rmcp handles timeout errors). + let swept_event_ids = event_routes_cleanup + .sweep_stale_routes(request_timeout) + .await; + if !swept_event_ids.is_empty() { + let mut kinds_w = request_wrap_kinds_cleanup.write().await; + for event_id in &swept_event_ids { + kinds_w.remove(event_id); + } + drop(kinds_w); + tracing::warn!( + target: LOG_TARGET, + swept = swept_event_ids.len(), + timeout_secs = request_timeout.as_secs(), + "Swept stale event routes (rmcp handles timeout errors)" + ); + } } }); @@ -1503,6 +1525,7 @@ mod tests { assert!(config.excluded_capabilities.is_empty()); assert_eq!(config.cleanup_interval, Duration::from_secs(60)); assert_eq!(config.session_timeout, Duration::from_secs(300)); + assert_eq!(config.request_timeout, Duration::from_secs(60)); assert!(config.server_info.is_none()); assert!(config.log_file_path.is_none()); } From 5d4d5b7a7aa6e8825f5edb96022191f09b7b51dd Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 5 May 2026 03:57:08 +0530 Subject: [PATCH 60/69] fix: remove single-peer barrier in RMCP worker and add LRU-bounded SessionStore --- src/gateway/mod.rs | 1 + src/relay/mock.rs | 18 ++ src/rmcp_transport/pipeline_tests.rs | 158 +++++++++------ src/rmcp_transport/worker.rs | 130 ++++-------- src/transport/server/mod.rs | 52 +++-- src/transport/server/session_store.rs | 279 ++++++++++++++++++++++---- tests/conformance_stores.rs | 23 ++- tests/transport_integration.rs | 175 ++++++++++++++++ 8 files changed, 608 insertions(+), 228 deletions(-) diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index e9d00f2..443b649 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -126,6 +126,7 @@ mod tests { is_announced_server: true, allowed_public_keys: vec!["abc123".to_string()], excluded_capabilities: vec![], + max_sessions: 1000, cleanup_interval: Duration::from_secs(120), session_timeout: Duration::from_secs(600), log_file_path: None, diff --git a/src/relay/mock.rs b/src/relay/mock.rs index a950235..2b181d6 100644 --- a/src/relay/mock.rs +++ b/src/relay/mock.rs @@ -105,6 +105,24 @@ impl MockRelayPool { (a, b) } + /// Create `n` linked mock relay pools with different signing keys. + /// + /// All pools share the same event store and notification channel so events + /// published by any one pool are visible to all others' `notifications()` + /// receivers. Useful for multi-client integration tests. + pub fn create_linked_group(n: usize) -> Vec { + assert!(n > 0, "group must have at least one pool"); + let (tx, _rx) = tokio::sync::broadcast::channel(1024); + let inner = Arc::new(Mutex::new(MockRelayInner::new())); + (0..n) + .map(|_| Self { + inner: Arc::clone(&inner), + notification_tx: tx.clone(), + keys: Keys::generate(), + }) + .collect() + } + /// Clone of all events published so far (useful for assertions in tests). pub async fn stored_events(&self) -> Vec { self.inner.lock().await.events.clone() diff --git a/src/rmcp_transport/pipeline_tests.rs b/src/rmcp_transport/pipeline_tests.rs index a1a6859..a036799 100644 --- a/src/rmcp_transport/pipeline_tests.rs +++ b/src/rmcp_transport/pipeline_tests.rs @@ -14,16 +14,13 @@ #[cfg(all(test, feature = "rmcp"))] mod tests { - use std::collections::HashMap; - use rmcp::model::{ ClientJsonRpcMessage, ClientResult, RequestId, ServerJsonRpcMessage, ServerResult, }; use crate::core::serializers; use crate::core::types::{ - JsonRpcError, JsonRpcErrorResponse, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, - JsonRpcResponse, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, }; use crate::rmcp_transport::convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, @@ -210,89 +207,120 @@ mod tests { assert_eq!(v["result"]["tools"], serde_json::json!([])); } - // ── Layer 5: ID correlation map logic (mirrors NostrServerWorker) ──────── + // ── Layer 5: event_id-based request correlation (mirrors NostrServerWorker) ── #[test] - fn layer5_worker_correlation_map_number_id() { - let mut request_id_to_event_id: HashMap = HashMap::new(); - let fake_event_id = "aaaaaa".to_string(); - - // Step 1: incoming request arrives — worker stores req_id → event_id - let req = JsonRpcMessage::Request(JsonRpcRequest { + fn layer5_worker_uses_event_id_as_request_id() { + // Simulate the worker rewriting req.id to the Nostr event_id. + let event_id = "abc123def456"; + let mut req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(42), method: "tools/list".to_string(), params: None, - }); + }; - if let JsonRpcMessage::Request(ref r) = req { - let key = serde_json::to_string(&r.id).unwrap(); - request_id_to_event_id.insert(key, fake_event_id.clone()); - } + // Worker inbound path: rewrite id to event_id + req.id = serde_json::json!(event_id); + assert_eq!(req.id, serde_json::json!("abc123def456")); - // Step 2: rmcp response comes back with id=42 - let response = JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(42), - result: serde_json::json!({}), - }); + // Convert through rmcp bridge — ID must survive the roundtrip + let msg = JsonRpcMessage::Request(req); + let rmcp_rx = internal_to_rmcp_server_rx(&msg).unwrap(); + let v = serde_json::to_value(&rmcp_rx).unwrap(); + assert_eq!(v["id"], serde_json::json!("abc123def456")); - // Step 3: worker looks up the event_id to call send_response - if let JsonRpcMessage::Response(ref r) = response { - let key = serde_json::to_string(&r.id).unwrap(); - let found = request_id_to_event_id.remove(&key); - assert_eq!(found, Some(fake_event_id)); - } else { - panic!("expected Response"); - } + // Simulate rmcp handler echoing the event_id back in the response + let rmcp_tx = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id)), + ); + let response = rmcp_server_tx_to_internal(rmcp_tx).unwrap(); - // Map should be empty after handling - assert!(request_id_to_event_id.is_empty()); + // The response ID is the event_id — worker passes it directly to send_response + match response { + JsonRpcMessage::Response(r) => { + assert_eq!(r.id.as_str(), Some(event_id)); + } + other => panic!("expected Response, got {other:?}"), + } } #[test] - fn layer5_worker_correlation_map_string_id() { - let mut request_id_to_event_id: HashMap = HashMap::new(); - let fake_event_id = "bbbbbb".to_string(); - - // String IDs serialize with surrounding quotes: "\"req-abc\"" - let req_id = serde_json::json!("req-abc"); - let key = serde_json::to_string(&req_id).unwrap(); - request_id_to_event_id.insert(key.clone(), fake_event_id.clone()); - - // The response ID serializes identically - let resp_id = serde_json::json!("req-abc"); - let resp_key = serde_json::to_string(&resp_id).unwrap(); - - // Key derived from response ID must match the one stored from request ID - assert_eq!(key, resp_key); - assert_eq!( - request_id_to_event_id.remove(&resp_key), - Some(fake_event_id) + fn layer5_worker_two_clients_no_collision() { + // Two clients both send requests with id: 1. The worker rewrites each + // to its unique Nostr event_id, so no collision occurs. + let event_id_a = "aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111aaaa1111"; + let event_id_b = "bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222bbbb2222"; + + let mut req_a = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }; + let mut req_b = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }; + + // Worker rewrites both to their respective event IDs + req_a.id = serde_json::json!(event_id_a); + req_b.id = serde_json::json!(event_id_b); + + // After rewrite, the IDs are distinct even though both clients sent id: 1 + assert_ne!(req_a.id, req_b.id); + assert_eq!(req_a.id.as_str(), Some(event_id_a)); + assert_eq!(req_b.id.as_str(), Some(event_id_b)); + + // Responses echo back the event_id — each routes to the correct client + let rmcp_resp_a = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id_a)), + ); + let rmcp_resp_b = ServerJsonRpcMessage::response( + ServerResult::empty(()), + RequestId::String(std::sync::Arc::from(event_id_b)), ); + + let resp_a = rmcp_server_tx_to_internal(rmcp_resp_a).unwrap(); + let resp_b = rmcp_server_tx_to_internal(rmcp_resp_b).unwrap(); + + // Each response carries its own event_id — no cross-wiring + assert_eq!(resp_a.id().unwrap().as_str(), Some(event_id_a)); + assert_eq!(resp_b.id().unwrap().as_str(), Some(event_id_b)); } #[test] - fn layer5_error_response_correlation_works() { - let mut map: HashMap = HashMap::new(); - map.insert( - serde_json::to_string(&serde_json::json!(5)).unwrap(), - "evt5".to_string(), - ); - - let error_response = JsonRpcMessage::ErrorResponse(JsonRpcErrorResponse { + fn layer5_error_response_carries_event_id() { + // Error responses also carry the event_id for routing. + let event_id = "deadbeef"; + let mut req = JsonRpcRequest { jsonrpc: "2.0".to_string(), id: serde_json::json!(5), - error: JsonRpcError { - code: -32601, - message: "Method not found".to_string(), + method: "tools/call".to_string(), + params: None, + }; + req.id = serde_json::json!(event_id); + + // rmcp handler returns an error with the rewritten event_id + let rmcp_err = ServerJsonRpcMessage::error( + rmcp::model::ErrorData { + code: rmcp::model::ErrorCode::METHOD_NOT_FOUND, + message: "Method not found".into(), data: None, }, - }); + RequestId::String(std::sync::Arc::from(event_id)), + ); + let internal = rmcp_server_tx_to_internal(rmcp_err).unwrap(); - if let JsonRpcMessage::ErrorResponse(ref r) = error_response { - let key = serde_json::to_string(&r.id).unwrap(); - assert_eq!(map.remove(&key), Some("evt5".to_string())); + match internal { + JsonRpcMessage::ErrorResponse(r) => { + assert_eq!(r.id.as_str(), Some(event_id)); + } + other => panic!("expected ErrorResponse, got {other:?}"), } } diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 79bb3dd..a42d57e 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -3,8 +3,6 @@ //! This file defines wrapper types that bind existing ContextVM Nostr //! transports to rmcp's worker abstraction. -use std::collections::HashMap; - use crate::core::error::Result; use crate::core::types::JsonRpcMessage; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; @@ -19,12 +17,16 @@ use super::convert::{ const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::worker"; /// rmcp server worker wrapper for ContextVM Nostr server transport. +/// +/// Multiplexes all connected clients through a single rmcp service instance. +/// Inbound requests have their JSON-RPC `id` rewritten to the Nostr `event_id` +/// before being forwarded to the rmcp handler. Since event IDs are globally +/// unique (SHA-256 hashes), this eliminates collisions when different clients +/// use the same JSON-RPC request IDs. The transport's event-route store +/// handles response routing back to the originating client; server-initiated +/// notifications are broadcast to all initialized clients. pub struct NostrServerWorker { transport: NostrServerTransport, - // rmcp service instance is single-peer. Keep one active client per worker. - active_client_pubkey: Option, - // Maps request id (serialized JSON value) -> incoming Nostr event id. - request_id_to_event_id: HashMap, } impl NostrServerWorker { @@ -34,11 +36,7 @@ impl NostrServerWorker { T: nostr_sdk::prelude::IntoNostrSigner, { let transport = NostrServerTransport::new(signer, config).await?; - Ok(Self { - transport, - active_client_pubkey: None, - request_id_to_event_id: HashMap::new(), - }) + Ok(Self { transport }) } /// Access the wrapped transport. @@ -88,46 +86,18 @@ impl Worker for NostrServerWorker { }; let crate::transport::server::IncomingRequest { - message, - client_pubkey, + mut message, event_id, .. } = incoming; - match &self.active_client_pubkey { - Some(active) if active != &client_pubkey => { - tracing::warn!( - target: LOG_TARGET, - active_client = %active, - ignored_client = %client_pubkey, - "Ignoring message from second client: rmcp server worker currently supports one active client per worker" - ); - continue; - } - None => { - tracing::info!( - target: LOG_TARGET, - client_pubkey = %client_pubkey, - "Binding rmcp server worker to first client session" - ); - self.active_client_pubkey = Some(client_pubkey.clone()); - } - _ => {} - } - - if let JsonRpcMessage::Request(req) = &message { - match serde_json::to_string(&req.id) { - Ok(request_key) => { - self.request_id_to_event_id.insert(request_key, event_id); - } - Err(e) => { - tracing::warn!( - target: LOG_TARGET, - error = %e, - "Failed to serialize request id for correlation map" - ); - } - } + // Rewrite the JSON-RPC request ID to the Nostr event_id. + // Event IDs are globally unique (SHA-256), so no collision + // across clients. The transport's event-route store maps + // event_id → (client_pubkey, original_request_id) and + // restores the original ID in `send_response`. + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!(event_id); } if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { @@ -276,76 +246,44 @@ impl Worker for NostrClientWorker { } impl NostrServerWorker { + /// Forward an outbound message from the rmcp handler to the Nostr transport. + /// + /// Response IDs carry the Nostr event_id set during ingest. The transport's + /// `send_response` uses this to look up the route (client_pubkey + + /// original_request_id) and deliver the response to the correct client. + /// Notifications and server-initiated requests are broadcast to all + /// initialized clients. async fn forward_server_internal(&mut self, message: JsonRpcMessage) -> Result<()> { match message { JsonRpcMessage::Response(resp) => { - let request_key = serde_json::to_string(&resp.id).map_err(|e| { - crate::core::error::Error::Validation(format!( - "failed to serialize rmcp response id for correlation lookup: {e}" - )) + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server response id is not a string event_id".to_string(), + ) })?; - let event_id = - if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { - event_id - } else { - resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( - "rmcp server response id has no known correlation mapping and is not a string event id" - .to_string(), - ) - })? - }; - self.transport .send_response(&event_id, JsonRpcMessage::Response(resp)) .await } JsonRpcMessage::ErrorResponse(resp) => { - let request_key = serde_json::to_string(&resp.id).map_err(|e| { - crate::core::error::Error::Validation(format!( - "failed to serialize rmcp error response id for correlation lookup: {e}" - )) + let event_id = resp.id.as_str().map(str::to_owned).ok_or_else(|| { + crate::core::error::Error::Validation( + "rmcp server error response id is not a string event_id".to_string(), + ) })?; - let event_id = - if let Some(event_id) = self.request_id_to_event_id.remove(&request_key) { - event_id - } else { - resp.id.as_str().map(str::to_owned).ok_or_else(|| { - crate::core::error::Error::Validation( - "rmcp server error response id has no known correlation mapping and is not a string event id" - .to_string(), - ) - })? - }; - self.transport .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) .await } JsonRpcMessage::Notification(notification) => { - let target = self.active_client_pubkey.as_deref().ok_or_else(|| { - crate::core::error::Error::Validation( - "cannot forward rmcp server notification: no active client bound" - .to_string(), - ) - })?; let message = JsonRpcMessage::Notification(notification); - self.transport - .send_notification(target, &message, None) - .await + self.transport.broadcast_notification(&message).await } JsonRpcMessage::Request(request) => { - let target = self.active_client_pubkey.as_deref().ok_or_else(|| { - crate::core::error::Error::Validation( - "cannot forward rmcp server request: no active client bound".to_string(), - ) - })?; let message = JsonRpcMessage::Request(request); - self.transport - .send_notification(target, &message, None) - .await + self.transport.broadcast_notification(&message).await } } } diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 022aee1..8d75fc5 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -48,6 +48,8 @@ pub struct NostrServerTransportConfig { pub allowed_public_keys: Vec, /// Capabilities excluded from pubkey whitelisting. pub excluded_capabilities: Vec, + /// Maximum number of concurrent client sessions (LRU-bounded, default: 1000). + pub max_sessions: usize, /// Session cleanup interval (default: 60s). pub cleanup_interval: Duration, /// Session timeout (default: 300s). @@ -66,6 +68,7 @@ impl Default for NostrServerTransportConfig { is_announced_server: false, allowed_public_keys: Vec::new(), excluded_capabilities: Vec::new(), + max_sessions: session_store::DEFAULT_MAX_SESSIONS, cleanup_interval: Duration::from_secs(60), session_timeout: Duration::from_secs(300), log_file_path: None, @@ -147,10 +150,10 @@ impl NostrServerTransport { encryption_mode: config.encryption_mode, is_connected: false, }, + sessions: SessionStore::with_capacity(config.max_sessions), config, extra_common_tags: Vec::new(), pricing_tags: Vec::new(), - sessions: SessionStore::new(), event_routes: ServerEventRouteStore::new(), request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), seen_gift_wrap_ids, @@ -184,10 +187,10 @@ impl NostrServerTransport { encryption_mode: config.encryption_mode, is_connected: false, }, + sessions: SessionStore::with_capacity(config.max_sessions), config, extra_common_tags: Vec::new(), pricing_tags: Vec::new(), - sessions: SessionStore::new(), request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), event_routes: ServerEventRouteStore::new(), seen_gift_wrap_ids, @@ -1054,10 +1057,21 @@ impl NostrServerTransport { } // Session management + let on_evicted_cb = sessions.eviction_callback(); let mut sessions_w = sessions.write().await; - let session = sessions_w - .entry(sender_pubkey.clone()) - .or_insert_with(|| ClientSession::new(is_encrypted)); + if !sessions_w.contains(&sender_pubkey) { + let evicted = + sessions_w.push(sender_pubkey.clone(), ClientSession::new(is_encrypted)); + SessionStore::handle_eviction( + &sender_pubkey, + evicted, + &mut sessions_w, + on_evicted_cb.as_ref(), + &event_routes, + ) + .await; + } + let session = sessions_w.get_mut(&sender_pubkey).unwrap(); session.update_activity(); session.is_encrypted = is_encrypted; @@ -1156,21 +1170,25 @@ impl NostrServerTransport { let mut cleaned = 0; let mut stale_event_ids = Vec::new(); - sessions_w.retain(|pubkey, session| { - if session.last_activity.elapsed() > timeout { + // LruCache has no retain(); collect expired keys then pop each one. + let expired_keys: Vec = sessions_w + .iter() + .filter(|(_, session)| session.last_activity.elapsed() > timeout) + .map(|(k, _)| k.clone()) + .collect(); + + for key in &expired_keys { + if let Some(session) = sessions_w.pop(key) { stale_event_ids.extend(session.pending_requests.keys().cloned()); stale_event_ids.extend(session.event_to_progress_token.keys().cloned()); tracing::debug!( target: LOG_TARGET, - client_pubkey = %pubkey, + client_pubkey = %key, "Session expired" ); cleaned += 1; - false - } else { - true } - }); + } drop(sessions_w); { @@ -1306,10 +1324,7 @@ mod tests { session .pending_requests .insert("evt1".to_string(), serde_json::json!(1)); - sessions - .write() - .await - .insert("pubkey1".to_string(), session); + sessions.write().await.put("pubkey1".to_string(), session); event_routes .register( "evt1".to_string(), @@ -1352,7 +1367,9 @@ mod tests { let event_routes = ServerEventRouteStore::new(); let request_wrap_kinds = Arc::new(RwLock::new(HashMap::new())); - sessions.get_or_create_session("active", false).await; + sessions + .get_or_create_session("active", false, &event_routes) + .await; let cleaned = NostrServerTransport::cleanup_sessions( &sessions, @@ -1501,6 +1518,7 @@ mod tests { assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); assert!(config.allowed_public_keys.is_empty()); assert!(config.excluded_capabilities.is_empty()); + assert_eq!(config.max_sessions, 1000); assert_eq!(config.cleanup_interval, Duration::from_secs(60)); assert_eq!(config.session_timeout, Duration::from_secs(300)); assert!(config.server_info.is_none()); diff --git a/src/transport/server/session_store.rs b/src/transport/server/session_store.rs index 6188482..1415a90 100644 --- a/src/transport/server/session_store.rs +++ b/src/transport/server/session_store.rs @@ -1,16 +1,40 @@ //! Server-side session store for managing client sessions. - -use std::collections::HashMap; +//! +//! Uses an LRU cache bounded by `max_sessions` (default 1000, matching the TS SDK +//! server session store). When a new session would exceed capacity the +//! least-recently-used session is evicted. If the evicted session still has +//! active routes in the correlation store it is recreated with clean state +//! (eviction safety, matching TS SDK's `hasActiveRoutesForClient` check), and +//! the optional eviction callback fires so external code can clean up resources. + +use std::num::NonZeroUsize; use std::sync::Arc; +use lru::LruCache; use tokio::sync::RwLock; use crate::core::types::ClientSession; +use crate::transport::server::ServerEventRouteStore; + +const LOG_TARGET: &str = "contextvm_sdk::transport::server::session_store"; + +/// Default maximum number of concurrent client sessions. +/// +/// Matches the TS SDK's `SessionStore` default (`maxSessions ?? 1000`), not +/// the broader `DEFAULT_LRU_SIZE` constant (5000) used elsewhere in the TS SDK. +pub const DEFAULT_MAX_SESSIONS: usize = 1000; + +/// Callback invoked when a session is evicted from the LRU cache. +/// Receives the evicted client's public key (hex). +pub type EvictionCallback = Arc; /// Manages client sessions keyed by public key (hex). +/// +/// Backed by an LRU cache so memory usage is bounded. #[derive(Clone)] pub struct SessionStore { - sessions: Arc>>, + sessions: Arc>>, + on_evicted: Option, } impl Default for SessionStore { @@ -20,20 +44,58 @@ impl Default for SessionStore { } impl SessionStore { + /// Create a store with the default capacity ([`DEFAULT_MAX_SESSIONS`]). pub fn new() -> Self { + Self::with_capacity(DEFAULT_MAX_SESSIONS) + } + + /// Create a store with a specific maximum number of sessions. + pub fn with_capacity(max_sessions: usize) -> Self { Self { - sessions: Arc::new(RwLock::new(HashMap::new())), + sessions: Arc::new(RwLock::new(LruCache::new( + NonZeroUsize::new(max_sessions).expect("max_sessions must be > 0"), + ))), + on_evicted: None, } } + /// Register a callback that fires when a session is evicted from the LRU. + pub fn set_eviction_callback(&mut self, cb: EvictionCallback) { + self.on_evicted = Some(cb); + } + + /// Clone the eviction callback (cheap Arc clone) for use outside the lock. + pub fn eviction_callback(&self) -> Option { + self.on_evicted.clone() + } + /// Get an existing session or create a new one. Returns `true` if a new session was created. - pub async fn get_or_create_session(&self, client_pubkey: &str, is_encrypted: bool) -> bool { + /// + /// `event_routes` is consulted during eviction safety: if the evicted client + /// still has active routes, the session is recreated with clean state + /// (matching TS SDK's `hasActiveRoutesForClient` check). + pub async fn get_or_create_session( + &self, + client_pubkey: &str, + is_encrypted: bool, + event_routes: &ServerEventRouteStore, + ) -> bool { + let on_evicted = self.on_evicted.clone(); let mut sessions = self.sessions.write().await; if let Some(session) = sessions.get_mut(client_pubkey) { session.is_encrypted = is_encrypted; false } else { - sessions.insert(client_pubkey.to_string(), ClientSession::new(is_encrypted)); + let new_session = ClientSession::new(is_encrypted); + let evicted = sessions.push(client_pubkey.to_string(), new_session); + Self::handle_eviction( + client_pubkey, + evicted, + &mut sessions, + on_evicted.as_ref(), + event_routes, + ) + .await; true } } @@ -42,7 +104,7 @@ impl SessionStore { /// Returns `None` if the session does not exist. pub async fn get_session(&self, client_pubkey: &str) -> Option { let sessions = self.sessions.read().await; - sessions.get(client_pubkey).map(|s| SessionSnapshot { + sessions.peek(client_pubkey).map(|s| SessionSnapshot { is_initialized: s.is_initialized, is_encrypted: s.is_encrypted, has_sent_common_tags: s.has_sent_common_tags, @@ -74,7 +136,7 @@ impl SessionStore { /// Remove a session. Returns `true` if it existed. pub async fn remove_session(&self, client_pubkey: &str) -> bool { - self.sessions.write().await.remove(client_pubkey).is_some() + self.sessions.write().await.pop(client_pubkey).is_some() } /// Remove all sessions. @@ -106,19 +168,59 @@ impl SessionStore { .collect() } - /// Acquire write access to the underlying map (transport internals only). + /// Acquire write access to the underlying LRU cache (transport internals only). pub(crate) async fn write( &self, - ) -> tokio::sync::RwLockWriteGuard<'_, HashMap> { + ) -> tokio::sync::RwLockWriteGuard<'_, LruCache> { self.sessions.write().await } - /// Acquire read access to the underlying map (transport internals only). + /// Acquire read access to the underlying LRU cache (transport internals only). pub(crate) async fn read( &self, - ) -> tokio::sync::RwLockReadGuard<'_, HashMap> { + ) -> tokio::sync::RwLockReadGuard<'_, LruCache> { self.sessions.read().await } + + /// Handle a potential LRU eviction after inserting a session. + /// + /// If the evicted client still has active routes in the correlation store, + /// a clean session is re-inserted (eviction safety, matching TS SDK's + /// `hasActiveRoutesForClient` check). The eviction callback fires only + /// for genuine, non-vetoed evictions. + pub(crate) async fn handle_eviction( + inserted_key: &str, + evicted: Option<(String, ClientSession)>, + sessions: &mut LruCache, + on_evicted: Option<&EvictionCallback>, + event_routes: &ServerEventRouteStore, + ) { + if let Some((evicted_key, evicted_session)) = evicted { + // `push` also returns the old value when the *same* key is updated; + // only act when a *different* key was evicted due to capacity. + if evicted_key != inserted_key { + if event_routes + .has_active_routes_for_client(&evicted_key) + .await + { + tracing::warn!( + target: LOG_TARGET, + client_pubkey = %evicted_key, + "LRU eviction of session with active routes; recreating with clean state" + ); + // Re-insert with clean state so the client isn't orphaned. + // Skip the external callback — the session still exists + // (matches TS SDK: vetoed evictions don't fire the callback). + let _ = sessions.push( + evicted_key.clone(), + ClientSession::new(evicted_session.is_encrypted), + ); + } else if let Some(cb) = on_evicted { + cb(evicted_key); + } + } + } + } } /// A lightweight snapshot of session state (avoids exposing the full `ClientSession` @@ -134,12 +236,18 @@ pub struct SessionSnapshot { #[cfg(test)] mod tests { use super::*; + use serde_json::json; + + fn routes() -> ServerEventRouteStore { + ServerEventRouteStore::new() + } #[tokio::test] async fn create_and_retrieve_session() { let store = SessionStore::new(); + let r = routes(); - let created = store.get_or_create_session("client-1", true).await; + let created = store.get_or_create_session("client-1", true, &r).await; assert!(created); let snap = store.get_session("client-1").await.unwrap(); @@ -150,14 +258,14 @@ mod tests { #[tokio::test] async fn get_or_create_returns_existing() { let store = SessionStore::new(); + let r = routes(); - let created = store.get_or_create_session("client-1", false).await; + let created = store.get_or_create_session("client-1", false, &r).await; assert!(created); - let created2 = store.get_or_create_session("client-1", true).await; + let created2 = store.get_or_create_session("client-1", true, &r).await; assert!(!created2); - // is_encrypted should have been updated. let snap = store.get_session("client-1").await.unwrap(); assert!(snap.is_encrypted); } @@ -165,7 +273,8 @@ mod tests { #[tokio::test] async fn mark_initialized() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; assert!(store.mark_initialized("client-1").await); let snap = store.get_session("client-1").await.unwrap(); @@ -181,7 +290,8 @@ mod tests { #[tokio::test] async fn remove_session() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; assert!(store.remove_session("client-1").await); assert!(store.get_session("client-1").await.is_none()); } @@ -195,8 +305,9 @@ mod tests { #[tokio::test] async fn clear_all_sessions() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; - store.get_or_create_session("client-2", true).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; store.clear().await; @@ -208,8 +319,9 @@ mod tests { #[tokio::test] async fn get_all_sessions() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; - store.get_or_create_session("client-2", true).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; let all = store.get_all_sessions().await; assert_eq!(all.len(), 2); @@ -224,10 +336,11 @@ mod tests { #[tokio::test] async fn new_session_capability_fields_default_false() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; let sessions = store.read().await; - let session = sessions.get("client-1").unwrap(); + let session = sessions.peek("client-1").unwrap(); assert!(!session.has_sent_common_tags); assert!(!session.supports_encryption); assert!(!session.supports_ephemeral_encryption); @@ -237,7 +350,8 @@ mod tests { #[tokio::test] async fn has_sent_common_tags_flag() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; let mut sessions = store.write().await; let session = sessions.get_mut("client-1").unwrap(); @@ -249,9 +363,9 @@ mod tests { #[tokio::test] async fn capability_or_assign_persists() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; - // First update: learn encryption support { let mut sessions = store.write().await; let session = sessions.get_mut("client-1").unwrap(); @@ -259,16 +373,15 @@ mod tests { session.supports_ephemeral_encryption |= false; } - // Second update: learn ephemeral support; encryption stays true { let mut sessions = store.write().await; let session = sessions.get_mut("client-1").unwrap(); - session.supports_encryption |= false; // should stay true + session.supports_encryption |= false; session.supports_ephemeral_encryption |= true; } let sessions = store.read().await; - let session = sessions.get("client-1").unwrap(); + let session = sessions.peek("client-1").unwrap(); assert!(session.supports_encryption, "OR-assign must not downgrade"); assert!(session.supports_ephemeral_encryption); assert!(!session.supports_oversized_transfer); @@ -277,8 +390,9 @@ mod tests { #[tokio::test] async fn capability_fields_independent_per_client() { let store = SessionStore::new(); - store.get_or_create_session("client-a", false).await; - store.get_or_create_session("client-b", false).await; + let r = routes(); + store.get_or_create_session("client-a", false, &r).await; + store.get_or_create_session("client-b", false, &r).await; { let mut sessions = store.write().await; @@ -288,8 +402,8 @@ mod tests { } let sessions = store.read().await; - let sa = sessions.get("client-a").unwrap(); - let sb = sessions.get("client-b").unwrap(); + let sa = sessions.peek("client-a").unwrap(); + let sb = sessions.peek("client-b").unwrap(); assert!(sa.supports_encryption); assert!(sa.has_sent_common_tags); assert!(!sb.supports_encryption); @@ -299,9 +413,9 @@ mod tests { #[tokio::test] async fn get_or_create_preserves_capability_fields() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; - // Set capability fields { let mut sessions = store.write().await; let session = sessions.get_mut("client-1").unwrap(); @@ -309,13 +423,11 @@ mod tests { session.has_sent_common_tags = true; } - // Re-enter via get_or_create (existing session) - let created = store.get_or_create_session("client-1", true).await; + let created = store.get_or_create_session("client-1", true, &r).await; assert!(!created); - // Capability fields must survive let sessions = store.read().await; - let session = sessions.get("client-1").unwrap(); + let session = sessions.peek("client-1").unwrap(); assert!(session.supports_encryption); assert!(session.has_sent_common_tags); } @@ -323,7 +435,8 @@ mod tests { #[tokio::test] async fn clear_resets_capability_fields() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; { let mut sessions = store.write().await; let s = sessions.get_mut("client-1").unwrap(); @@ -331,11 +444,91 @@ mod tests { } store.clear().await; - store.get_or_create_session("client-1", false).await; + store.get_or_create_session("client-1", false, &r).await; let sessions = store.read().await; - let session = sessions.get("client-1").unwrap(); + let session = sessions.peek("client-1").unwrap(); assert!(!session.supports_encryption); assert!(!session.has_sent_common_tags); } + + // ── LRU eviction ──────────────────────────────────────────── + + #[tokio::test] + async fn lru_eviction_drops_oldest_session() { + let store = SessionStore::with_capacity(3); + let r = routes(); + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + store.get_or_create_session("d", false, &r).await; + + assert!( + store.get_session("a").await.is_none(), + "a should be evicted" + ); + assert!(store.get_session("b").await.is_some()); + assert!(store.get_session("c").await.is_some()); + assert!(store.get_session("d").await.is_some()); + assert_eq!(store.session_count().await, 3); + } + + #[tokio::test] + async fn eviction_callback_fires_on_lru_eviction() { + let evicted = Arc::new(std::sync::Mutex::new(Vec::::new())); + let evicted_clone = evicted.clone(); + let r = routes(); + + let mut store = SessionStore::with_capacity(2); + store.set_eviction_callback(Arc::new(move |pubkey| { + evicted_clone.lock().unwrap().push(pubkey); + })); + + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + let evicted = evicted.lock().unwrap(); + assert_eq!(evicted.len(), 1); + assert_eq!(evicted[0], "a"); + } + + #[tokio::test] + async fn eviction_safety_recreates_session_with_active_routes() { + let store = SessionStore::with_capacity(2); + let r = routes(); + store.get_or_create_session("a", true, &r).await; + store.get_or_create_session("b", false, &r).await; + + // Register an active route for client "a" in the correlation store + r.register("evt1".into(), "a".into(), json!(1), None).await; + + // Adding "c" would normally evict "a", but eviction safety recreates it + // because "a" has active routes. + store.get_or_create_session("c", false, &r).await; + + let snap = store.get_session("a").await; + assert!( + snap.is_some(), + "session with active routes must survive eviction" + ); + // "b" was evicted instead (next LRU after "a" was re-inserted) + assert!( + store.get_session("b").await.is_none(), + "b should be evicted" + ); + } + + #[tokio::test] + async fn with_capacity_sets_limit() { + let store = SessionStore::with_capacity(5); + let r = routes(); + for i in 0..10 { + store + .get_or_create_session(&format!("client-{i}"), false, &r) + .await; + } + assert_eq!(store.session_count().await, 5); + } } diff --git a/tests/conformance_stores.rs b/tests/conformance_stores.rs index 55917ec..8216403 100644 --- a/tests/conformance_stores.rs +++ b/tests/conformance_stores.rs @@ -138,11 +138,16 @@ mod client_correlation_store { mod server_session_store { use super::*; + fn routes() -> ServerEventRouteStore { + ServerEventRouteStore::new() + } + #[tokio::test] async fn create_and_retrieve_sessions() { let store = SessionStore::new(); + let r = routes(); - let created = store.get_or_create_session("client-1", true).await; + let created = store.get_or_create_session("client-1", true, &r).await; assert!(created); let session = store.get_session("client-1").await.unwrap(); @@ -156,7 +161,8 @@ mod server_session_store { #[tokio::test] async fn mark_sessions_as_initialized() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; let result = store.mark_initialized("client-1").await; assert!(result); @@ -168,7 +174,8 @@ mod server_session_store { #[tokio::test] async fn remove_sessions() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; let result = store.remove_session("client-1").await; assert!(result); @@ -178,8 +185,9 @@ mod server_session_store { #[tokio::test] async fn clear_all_sessions() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; - store.get_or_create_session("client-2", true).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; store.clear().await; @@ -191,8 +199,9 @@ mod server_session_store { #[tokio::test] async fn iterate_over_all_sessions() { let store = SessionStore::new(); - store.get_or_create_session("client-1", false).await; - store.get_or_create_session("client-2", true).await; + let r = routes(); + store.get_or_create_session("client-1", false, &r).await; + store.get_or_create_session("client-2", true, &r).await; let sessions = store.get_all_sessions().await; assert_eq!(sessions.len(), 2); diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index e1c4914..90ffe59 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -3348,3 +3348,178 @@ async fn client_optional_encryption_emits_discovery_tags() { "inner event must carry support_encryption_ephemeral tag (Optional gift-wrap mode)" ); } +// ── Multi-client support ───────────────────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn multi_client_concurrent_requests_both_get_responses() { + // Two different clients send requests to the same server; both must get + // their own response (the single-peer barrier is removed). + let mut pools = MockRelayPool::create_linked_group(3); + let server_pool = pools.remove(0); + let client_b_pool = pools.remove(1); + let client_a_pool = pools.remove(0); + let server_pubkey = server_pool.mock_public_key(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut client_a = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_a_pool), + ) + .await + .expect("create client A"); + + let mut client_b = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_b_pool), + ) + .await + .expect("create client B"); + + let mut server_rx = server + .take_message_receiver() + .expect("server message receiver"); + let mut client_a_rx = client_a + .take_message_receiver() + .expect("client A message receiver"); + let mut client_b_rx = client_b + .take_message_receiver() + .expect("client B message receiver"); + + server.start().await.expect("server start"); + client_a.start().await.expect("client A start"); + client_b.start().await.expect("client B start"); + let_event_loops_start().await; + + // Client A sends a request. + let req_a = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: None, + }); + client_a.send(&req_a).await.expect("client A send"); + + // Client B sends a request. + let req_b = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(2), + method: "tools/list".to_string(), + params: None, + }); + client_b.send(&req_b).await.expect("client B send"); + + // Server receives both requests (order may vary). + let incoming_1 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout rx 1") + .expect("rx closed 1"); + let incoming_2 = tokio::time::timeout(Duration::from_millis(500), server_rx.recv()) + .await + .expect("timeout rx 2") + .expect("rx closed 2"); + + // Send responses to both. + let resp_1 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: incoming_1.message.id().unwrap().clone(), + result: serde_json::json!({"tools": []}), + }); + server + .send_response(&incoming_1.event_id, resp_1) + .await + .expect("server respond to 1"); + + let resp_2 = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: incoming_2.message.id().unwrap().clone(), + result: serde_json::json!({"tools": []}), + }); + server + .send_response(&incoming_2.event_id, resp_2) + .await + .expect("server respond to 2"); + + // Both clients must receive their respective response. + let resp_a = tokio::time::timeout(Duration::from_millis(500), client_a_rx.recv()) + .await + .expect("timeout client A response") + .expect("client A channel closed"); + let resp_b = tokio::time::timeout(Duration::from_millis(500), client_b_rx.recv()) + .await + .expect("timeout client B response") + .expect("client B channel closed"); + + assert!( + matches!(resp_a, JsonRpcMessage::Response(_)), + "client A must receive a response" + ); + assert!( + matches!(resp_b, JsonRpcMessage::Response(_)), + "client B must receive a response" + ); +} + +// ── Session store LRU tests ───────────────────────────────────────────────── + +use contextvm_sdk::transport::server::SessionStore; +use contextvm_sdk::ServerEventRouteStore; + +#[tokio::test] +async fn session_store_lru_eviction() { + let store = SessionStore::with_capacity(3); + let r = ServerEventRouteStore::new(); + store.get_or_create_session("a", false, &r).await; + store.get_or_create_session("b", false, &r).await; + store.get_or_create_session("c", false, &r).await; + + // 4th session evicts the oldest ("a") + store.get_or_create_session("d", false, &r).await; + + assert!( + store.get_session("a").await.is_none(), + "oldest session must be evicted when capacity is exceeded" + ); + assert!(store.get_session("b").await.is_some()); + assert!(store.get_session("c").await.is_some()); + assert!(store.get_session("d").await.is_some()); + assert_eq!(store.session_count().await, 3); +} + +#[tokio::test] +async fn session_store_eviction_callback_fires() { + let evicted_keys: Arc>> = + Arc::new(std::sync::Mutex::new(Vec::new())); + let captured = evicted_keys.clone(); + let r = ServerEventRouteStore::new(); + + let mut store = SessionStore::with_capacity(2); + store.set_eviction_callback(std::sync::Arc::new(move |pubkey| { + captured.lock().unwrap().push(pubkey); + })); + + store.get_or_create_session("x", false, &r).await; + store.get_or_create_session("y", false, &r).await; + // Adding "z" evicts "x" + store.get_or_create_session("z", false, &r).await; + + let keys = evicted_keys.lock().unwrap(); + assert_eq!(keys.len(), 1, "callback must fire exactly once"); + assert_eq!(keys[0], "x", "evicted key must be the oldest session"); +} From 4eafd7890b8a53d9e81d9185f2d8fb8fcc392dd4 Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 5 May 2026 20:03:28 +0530 Subject: [PATCH 61/69] docs: clarify timeout and request_timeout as correlation-retention TTLs --- src/transport/client/mod.rs | 6 +++++- src/transport/server/mod.rs | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 0aaafc6..62c796d 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -41,7 +41,11 @@ pub struct NostrClientTransportConfig { pub gift_wrap_mode: GiftWrapMode, /// Stateless mode: emulate initialize response locally. pub is_stateless: bool, - /// Response timeout (default: 30s). + /// Correlation-retention TTL for pending client requests (default: 30s). + /// + /// Stale pending entries older than this are swept from the correlation store. + /// This prevents leaks -- rmcp owns actual request timeout and cancellation. + /// Keep this value above your rmcp request timeout to avoid premature cleanup. pub timeout: Duration, /// Optional log file path. Logs always go to stdout and are also appended here when set. pub log_file_path: Option, diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 9dd3fc3..4c1b808 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -52,7 +52,11 @@ pub struct NostrServerTransportConfig { pub cleanup_interval: Duration, /// Session timeout (default: 300s). pub session_timeout: Duration, - /// Request route timeout (default: 60s). Stale routes older than this are swept. + /// Correlation-retention TTL for server-side event routes (default: 60s). + /// + /// Stale route entries older than this are swept from the correlation store. + /// This prevents leaks -- rmcp owns actual request timeout and cancellation. + /// Keep this value above your rmcp request timeout to avoid premature cleanup. pub request_timeout: Duration, /// Optional log file path. Logs always go to stdout and are also appended here when set. pub log_file_path: Option, From bd381bec81a8594cf4b12dee1455b18c24c8a540 Mon Sep 17 00:00:00 2001 From: Harsh Date: Tue, 5 May 2026 21:59:49 +0530 Subject: [PATCH 62/69] fix: cancel spawned event loop tasks on close() using CancellationToken --- Cargo.toml | 3 ++ src/transport/client/mod.rs | 52 ++++++++++++++---- src/transport/server/mod.rs | 98 +++++++++++++++++++++++++--------- tests/transport_integration.rs | 62 +++++++++++++++++++++ 4 files changed, 181 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 95e675a..9ed83c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transpor # LRU cache for gift-wrap (outer event id) deduplication lru = "0.12" +# CancellationToken for graceful event-loop shutdown +tokio-util = { version = "0.7", features = ["rt"] } + [features] # Enable rmcp by default while keeping legacy APIs available. default = ["rmcp"] diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 62c796d..53f494b 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -14,6 +14,7 @@ use std::time::Duration; use lru::LruCache; use nostr_sdk::prelude::*; +use tokio_util::sync::CancellationToken; use crate::core::constants::*; use crate::core::error::{Error, Result}; @@ -85,8 +86,12 @@ pub struct NostrClientTransport { /// so failed decrypt/verify can be retried on redelivery. seen_gift_wrap_ids: Arc>>, /// Channel for receiving processed MCP messages from the event loop. - message_tx: tokio::sync::mpsc::UnboundedSender, + message_tx: Option>, message_rx: Option>, + /// Token used to cancel the spawned event loop on close(). + cancellation_token: CancellationToken, + /// Handle for the spawned event loop task. + event_loop_handle: Option>, } impl NostrClientTransport { @@ -142,8 +147,10 @@ impl NostrClientTransport { server_initialize_event: Arc::new(Mutex::new(None)), server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, - message_tx: tx, + message_tx: Some(tx), message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + event_loop_handle: None, }) } @@ -190,8 +197,10 @@ impl NostrClientTransport { server_initialize_event: Arc::new(Mutex::new(None)), server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids, - message_tx: tx, + message_tx: Some(tx), message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + event_loop_handle: None, }) } @@ -236,11 +245,15 @@ impl NostrClientTransport { error })?; - // Spawn event loop + // Spawn event loop with cancellation support let relay_pool = Arc::clone(&self.base.relay_pool); let pending = self.pending_requests.clone(); let server_pubkey = self.server_pubkey; - let tx = self.message_tx.clone(); + let tx = self + .message_tx + .as_ref() + .expect("message_tx must exist before start()") + .clone(); let encryption_mode = self.config.encryption_mode; let gift_wrap_mode = self.config.gift_wrap_mode; let discovered_caps = self.discovered_server_capabilities.clone(); @@ -248,8 +261,9 @@ impl NostrClientTransport { let server_supports_ephemeral = self.server_supports_ephemeral.clone(); let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); let timeout = self.config.timeout; + let token = self.cancellation_token.child_token(); - tokio::spawn(async move { + self.event_loop_handle = Some(tokio::spawn(async move { Self::event_loop( relay_pool, pending, @@ -262,9 +276,10 @@ impl NostrClientTransport { server_supports_ephemeral, seen_gift_wrap_ids, timeout, + token, ) .await; - }); + })); tracing::info!( target: LOG_TARGET, @@ -274,8 +289,13 @@ impl NostrClientTransport { Ok(()) } - /// Close the transport. + /// Close the transport — cancels the event loop and disconnects from relays. pub async fn close(&mut self) -> Result<()> { + self.cancellation_token.cancel(); + if let Some(handle) = self.event_loop_handle.take() { + let _ = handle.await; + } + self.message_tx.take(); self.base.disconnect().await } @@ -384,7 +404,9 @@ impl NostrClientTransport { } }), }); - let _ = self.message_tx.send(response); + if let Some(ref tx) = self.message_tx { + let _ = tx.send(response); + } } #[allow(clippy::too_many_arguments)] @@ -400,6 +422,7 @@ impl NostrClientTransport { server_supports_ephemeral: Arc, seen_gift_wrap_ids: Arc>>, timeout: Duration, + cancel: CancellationToken, ) { let mut notifications = relay_pool.notifications(); // Sweep interval: half the timeout, clamped to [1s, 30s]. @@ -409,6 +432,13 @@ impl NostrClientTransport { loop { tokio::select! { + _ = cancel.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Client event loop cancelled" + ); + break; + } result = notifications.recv() => { let notification = match result { Ok(n) => n, @@ -1035,8 +1065,10 @@ mod tests { server_initialize_event: Arc::new(Mutex::new(None)), server_supports_ephemeral: Arc::new(AtomicBool::new(false)), seen_gift_wrap_ids: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(10).unwrap()))), - message_tx: tokio::sync::mpsc::unbounded_channel().0, + message_tx: Some(tokio::sync::mpsc::unbounded_channel().0), message_rx: None, + cancellation_token: CancellationToken::new(), + event_loop_handle: None, } } diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index ec4ebcb..a222a3e 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -18,6 +18,7 @@ use std::time::Duration; use lru::LruCache; use nostr_sdk::prelude::*; +use tokio_util::sync::CancellationToken; use crate::core::constants::*; use crate::core::error::{Error, Result}; @@ -104,8 +105,12 @@ pub struct NostrServerTransport { /// so failed decrypt/verify can be retried on redelivery. seen_gift_wrap_ids: Arc>>, /// Channel for incoming MCP messages (consumed by the MCP server). - message_tx: tokio::sync::mpsc::UnboundedSender, + message_tx: Option>, message_rx: Option>, + /// Token used to cancel spawned tasks (event loop + cleanup) on close(). + cancellation_token: CancellationToken, + /// Handles for spawned tasks (event loop + cleanup). + task_handles: Vec>, } /// An incoming MCP request with metadata for routing the response. @@ -164,8 +169,10 @@ impl NostrServerTransport { event_routes: ServerEventRouteStore::new(), request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), seen_gift_wrap_ids, - message_tx: tx, + message_tx: Some(tx), message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + task_handles: Vec::new(), }) } @@ -201,8 +208,10 @@ impl NostrServerTransport { request_wrap_kinds: Arc::new(RwLock::new(HashMap::new())), event_routes: ServerEventRouteStore::new(), seen_gift_wrap_ids, - message_tx: tx, + message_tx: Some(tx), message_rx: Some(rx), + cancellation_token: CancellationToken::new(), + task_handles: Vec::new(), }) } @@ -247,12 +256,16 @@ impl NostrServerTransport { error })?; - // Spawn event loop + // Spawn event loop with cancellation support let relay_pool = Arc::clone(&self.base.relay_pool); let sessions = self.sessions.clone(); let event_routes = self.event_routes.clone(); let request_wrap_kinds = self.request_wrap_kinds.clone(); - let tx = self.message_tx.clone(); + let tx = self + .message_tx + .as_ref() + .expect("message_tx must exist before start()") + .clone(); let allowed = self.config.allowed_public_keys.clone(); let excluded = self.config.excluded_capabilities.clone(); let encryption_mode = self.config.encryption_mode; @@ -261,8 +274,9 @@ impl NostrServerTransport { let server_info = self.config.server_info.clone(); let extra_common_tags = self.extra_common_tags.clone(); let seen_gift_wrap_ids = self.seen_gift_wrap_ids.clone(); + let event_loop_token = self.cancellation_token.child_token(); - tokio::spawn(async move { + let event_loop_handle = tokio::spawn(async move { Self::event_loop( relay_pool, sessions, @@ -277,35 +291,47 @@ impl NostrServerTransport { server_info, extra_common_tags, seen_gift_wrap_ids, + event_loop_token, ) .await; }); - // Spawn session cleanup + // Spawn session cleanup with cancellation support let sessions_cleanup = self.sessions.clone(); let event_routes_cleanup = self.event_routes.clone(); let request_wrap_kinds_cleanup = self.request_wrap_kinds.clone(); let cleanup_interval = self.config.cleanup_interval; let session_timeout = self.config.session_timeout; let request_timeout = self.config.request_timeout; + let cleanup_token = self.cancellation_token.child_token(); - tokio::spawn(async move { + let cleanup_handle = tokio::spawn(async move { let mut interval = tokio::time::interval(cleanup_interval); loop { - interval.tick().await; - let cleaned = Self::cleanup_sessions( - &sessions_cleanup, - &event_routes_cleanup, - &request_wrap_kinds_cleanup, - session_timeout, - ) - .await; - if cleaned > 0 { - tracing::info!( - target: LOG_TARGET, - cleaned_sessions = cleaned, - "Cleaned up inactive sessions" - ); + tokio::select! { + _ = cleanup_token.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Server cleanup task cancelled" + ); + break; + } + _ = interval.tick() => { + let cleaned = Self::cleanup_sessions( + &sessions_cleanup, + &event_routes_cleanup, + &request_wrap_kinds_cleanup, + session_timeout, + ) + .await; + if cleaned > 0 { + tracing::info!( + target: LOG_TARGET, + cleaned_sessions = cleaned, + "Cleaned up inactive sessions" + ); + } + } } // Sweep stale route entries in active sessions (rmcp handles timeout errors). @@ -328,6 +354,9 @@ impl NostrServerTransport { } }); + self.task_handles.push(event_loop_handle); + self.task_handles.push(cleanup_handle); + tracing::info!( target: LOG_TARGET, relay_count = self.config.relay_urls.len(), @@ -338,8 +367,13 @@ impl NostrServerTransport { Ok(()) } - /// Close the transport. + /// Close the transport — cancels event loop and cleanup tasks, then disconnects. pub async fn close(&mut self) -> Result<()> { + self.cancellation_token.cancel(); + for handle in self.task_handles.drain(..) { + let _ = handle.await; + } + self.message_tx.take(); self.base.disconnect().await?; self.sessions.clear().await; self.event_routes.clear().await; @@ -849,10 +883,26 @@ impl NostrServerTransport { server_info: Option, extra_common_tags: Vec, seen_gift_wrap_ids: Arc>>, + cancel: CancellationToken, ) { let mut notifications = relay_pool.notifications(); - while let Ok(notification) = notifications.recv().await { + loop { + let notification = tokio::select! { + _ = cancel.cancelled() => { + tracing::info!( + target: LOG_TARGET, + "Server event loop cancelled" + ); + break; + } + result = notifications.recv() => { + match result { + Ok(n) => n, + Err(_) => break, + } + } + }; if let RelayPoolNotification::Event { event, .. } = notification { let is_gift_wrap = event.kind == Kind::Custom(GIFT_WRAP_KIND) || event.kind == Kind::Custom(EPHEMERAL_GIFT_WRAP_KIND); diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index 90ffe59..5bddfad 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -3523,3 +3523,65 @@ async fn session_store_eviction_callback_fires() { assert_eq!(keys.len(), 1, "callback must fire exactly once"); assert_eq!(keys[0], "x", "evicted key must be the oldest session"); } + +// ── Event loop cancellation on close() ────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn client_close_stops_event_loop() { + let (client_pool, server_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool.mock_public_key(); + + let mut client = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig { + server_pubkey: server_pubkey.to_hex(), + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(client_pool), + ) + .await + .expect("create client transport"); + + let mut rx = client.take_message_receiver().expect("message receiver"); + client.start().await.expect("client start"); + let_event_loops_start().await; + + // Close should cancel the event loop, causing the rx channel to close. + client.close().await.expect("client close"); + + // The receiver must resolve to None (closed) within a short timeout. + let result = tokio::time::timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + matches!(result, Ok(None)), + "after close(), message receiver must yield None (channel closed)" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn server_close_stops_event_loop() { + let (_client_pool, server_pool) = MockRelayPool::create_pair(); + + let mut server = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig { + encryption_mode: EncryptionMode::Disabled, + ..Default::default() + }, + as_pool(server_pool), + ) + .await + .expect("create server transport"); + + let mut rx = server.take_message_receiver().expect("message receiver"); + server.start().await.expect("server start"); + let_event_loops_start().await; + + // Close should cancel both event loop and cleanup tasks. + server.close().await.expect("server close"); + + // The receiver must resolve to None (closed) within a short timeout. + let result = tokio::time::timeout(Duration::from_millis(200), rx.recv()).await; + assert!( + matches!(result, Ok(None)), + "after close(), message receiver must yield None (channel closed)" + ); +} From 4d705aaa9aa16a065813a9c30ce0cab1523f42cc Mon Sep 17 00:00:00 2001 From: ContextVM Date: Wed, 6 May 2026 17:46:22 +0200 Subject: [PATCH 63/69] refactor(rmcp): add direct transport adapters for native ContextVM services Introduce direct rmcp adapter entrypoints over the raw Nostr transports so native ContextVM servers and clients can follow the expected transport-first serve flow. --- examples/native_echo_client.rs | 96 ++++++++++++++++++++++++++++ examples/native_echo_server.rs | 107 ++++++++++++++++++++++++++++++++ src/gateway/mod.rs | 8 ++- src/proxy/mod.rs | 8 ++- src/rmcp_transport/mod.rs | 2 + src/rmcp_transport/transport.rs | 55 ++++++++++++++++ src/rmcp_transport/worker.rs | 10 +++ 7 files changed, 280 insertions(+), 6 deletions(-) create mode 100644 examples/native_echo_client.rs create mode 100644 examples/native_echo_server.rs create mode 100644 src/rmcp_transport/transport.rs diff --git a/examples/native_echo_client.rs b/examples/native_echo_client.rs new file mode 100644 index 0000000..7c46f87 --- /dev/null +++ b/examples/native_echo_client.rs @@ -0,0 +1,96 @@ +//! Example: Native rmcp client over ContextVM/Nostr. +//! +//! Usage: +//! cargo run --example native_echo_client -- + +use anyhow::{Context, Result}; +use contextvm_sdk::transport::client::{NostrClientTransport, NostrClientTransportConfig}; +use contextvm_sdk::{signer, EncryptionMode, GiftWrapMode}; +use rmcp::{ + model::{CallToolRequestParams, CallToolResult}, + ClientHandler, ServiceExt, +}; + +const RELAY_URL: &str = "wss://relay.contextvm.org"; + +#[derive(Clone, Default)] +struct EchoClient; + +impl ClientHandler for EchoClient {} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("contextvm_sdk=info".parse()?) + .add_directive("rmcp=warn".parse()?), + ) + .init(); + + let server_pubkey = std::env::args() + .nth(1) + .context("Usage: native_echo_client ")?; + + let signer = signer::generate(); + println!("Native ContextVM echo client starting"); + println!("Relay: {RELAY_URL}"); + println!("Client pubkey: {}", signer.public_key().to_hex()); + println!("Target server pubkey: {server_pubkey}"); + + let transport = NostrClientTransport::new( + signer, + NostrClientTransportConfig { + relay_urls: vec![RELAY_URL.to_string()], + server_pubkey, + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + ..Default::default() + }, + ) + .await?; + + let client = EchoClient.serve(transport.into_rmcp_transport()).await?; + + let peer_info = client + .peer_info() + .context("server did not provide peer info after initialize")?; + println!("Connected to: {:?}", peer_info.server_info.name); + + let tools = client.list_all_tools().await?; + println!("Discovered {} tool(s):", tools.len()); + for tool in &tools { + println!("- {}", tool.name); + } + + let result = client + .call_tool(CallToolRequestParams { + name: "echo".into(), + arguments: serde_json::from_value(serde_json::json!({ + "message": "hello from native contextvm client" + })) + .ok(), + meta: None, + task: None, + }) + .await?; + + println!("Echo result: {}", first_text(&result)); + + client.cancel().await?; + Ok(()) +} + +fn first_text(result: &CallToolResult) -> String { + result + .content + .iter() + .find_map(|content| { + if let rmcp::model::RawContent::Text(text) = &content.raw { + Some(text.text.clone()) + } else { + None + } + }) + .unwrap_or_default() +} diff --git a/examples/native_echo_server.rs b/examples/native_echo_server.rs new file mode 100644 index 0000000..dc3b302 --- /dev/null +++ b/examples/native_echo_server.rs @@ -0,0 +1,107 @@ +//! Example: Native rmcp echo server over ContextVM/Nostr. +//! +//! Usage: +//! cargo run --example native_echo_server + +use anyhow::Result; +use contextvm_sdk::transport::server::{NostrServerTransport, NostrServerTransportConfig}; +use contextvm_sdk::{signer, EncryptionMode, GiftWrapMode, ServerInfo}; +use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + model::*, + schemars, tool, tool_handler, tool_router, ServerHandler, ServiceExt, +}; + +const RELAY_URL: &str = "wss://relay.contextvm.org"; + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +struct EchoParams { + message: String, +} + +#[derive(Clone)] +struct EchoServer { + tool_router: ToolRouter, +} + +impl EchoServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } +} + +#[tool_router] +impl EchoServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "Echo: {message}" + ))])) + } +} + +#[tool_handler] +impl ServerHandler for EchoServer { + fn get_info(&self) -> rmcp::model::ServerInfo { + rmcp::model::ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "contextvm-native-echo".to_string(), + title: Some("ContextVM Native Echo Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Native rmcp echo server over ContextVM/Nostr".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Call the echo tool with a message string".to_string()), + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("contextvm_sdk=info".parse()?) + .add_directive("rmcp=warn".parse()?), + ) + .init(); + + let signer = signer::generate(); + let pubkey = signer.public_key().to_hex(); + + println!("Native ContextVM echo server starting"); + println!("Relay: {RELAY_URL}"); + println!("Server pubkey: {pubkey}"); + + let transport = NostrServerTransport::new( + signer, + NostrServerTransportConfig { + relay_urls: vec![RELAY_URL.to_string()], + encryption_mode: EncryptionMode::Optional, + gift_wrap_mode: GiftWrapMode::Optional, + is_announced_server: false, + server_info: Some(ServerInfo { + name: Some("contextvm-native-echo".to_string()), + about: Some("Native rmcp echo server example".to_string()), + ..Default::default() + }), + ..Default::default() + }, + ) + .await?; + + let service = EchoServer::new() + .serve(transport.into_rmcp_transport()) + .await?; + println!("Server ready. Press Ctrl+C to stop."); + service.waiting().await?; + Ok(()) +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index b5ae17d..0feb03a 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -94,12 +94,14 @@ impl NostrMCPGateway { T: nostr_sdk::prelude::IntoNostrSigner, H: rmcp::ServerHandler, { - use crate::rmcp_transport::NostrServerWorker; + use crate::NostrServerTransport; use rmcp::ServiceExt; - let worker = NostrServerWorker::new(signer, config.nostr_config).await?; + let transport = NostrServerTransport::new(signer, config.nostr_config) + .await? + .into_rmcp_transport(); handler - .serve(worker) + .serve(transport) .await .map_err(|e| Error::Other(format!("rmcp server initialization failed: {e}"))) } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 1322fa0..c6c1a34 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -83,12 +83,14 @@ impl NostrMCPProxy { T: nostr_sdk::prelude::IntoNostrSigner, H: rmcp::ClientHandler, { - use crate::rmcp_transport::NostrClientWorker; + use crate::NostrClientTransport; use rmcp::ServiceExt; - let worker = NostrClientWorker::new(signer, config.nostr_config).await?; + let transport = NostrClientTransport::new(signer, config.nostr_config) + .await? + .into_rmcp_transport(); handler - .serve(worker) + .serve(transport) .await .map_err(|e| Error::Other(format!("rmcp client initialization failed: {e}"))) } diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs index 57919b5..bf5c99f 100644 --- a/src/rmcp_transport/mod.rs +++ b/src/rmcp_transport/mod.rs @@ -3,6 +3,7 @@ //! This module bridges the existing Nostr transport implementation with rmcp services. pub mod convert; +pub mod transport; pub mod worker; #[cfg(test)] @@ -12,4 +13,5 @@ pub use convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, }; +pub use transport::{NostrClientRmcpTransport, NostrServerRmcpTransport}; pub use worker::{NostrClientWorker, NostrServerWorker}; diff --git a/src/rmcp_transport/transport.rs b/src/rmcp_transport/transport.rs new file mode 100644 index 0000000..f10223a --- /dev/null +++ b/src/rmcp_transport/transport.rs @@ -0,0 +1,55 @@ +//! Direct rmcp adapter entrypoints over raw ContextVM Nostr transports. + +use crate::{ + core::error::Error, + rmcp_transport::worker::{NostrClientWorker, NostrServerWorker}, + transport::{client::NostrClientTransport, server::NostrServerTransport}, +}; + +/// Direct rmcp adapter for [`NostrServerTransport`](src/transport/server/mod.rs:87). +pub struct NostrServerRmcpTransport { + worker: NostrServerWorker, +} + +impl NostrServerTransport { + /// Convert this raw transport into an rmcp-compatible transport adapter. + pub fn into_rmcp_transport(self) -> NostrServerRmcpTransport { + NostrServerRmcpTransport { + worker: NostrServerWorker::from_transport(self), + } + } +} + +impl rmcp::transport::IntoTransport + for NostrServerRmcpTransport +{ + fn into_transport( + self, + ) -> impl rmcp::transport::Transport + 'static { + self.worker.into_transport() + } +} + +/// Direct rmcp adapter for [`NostrClientTransport`](src/transport/client/mod.rs:69). +pub struct NostrClientRmcpTransport { + worker: NostrClientWorker, +} + +impl NostrClientTransport { + /// Convert this raw transport into an rmcp-compatible transport adapter. + pub fn into_rmcp_transport(self) -> NostrClientRmcpTransport { + NostrClientRmcpTransport { + worker: NostrClientWorker::from_transport(self), + } + } +} + +impl rmcp::transport::IntoTransport + for NostrClientRmcpTransport +{ + fn into_transport( + self, + ) -> impl rmcp::transport::Transport + 'static { + self.worker.into_transport() + } +} diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index a42d57e..4263c8d 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -39,6 +39,11 @@ impl NostrServerWorker { Ok(Self { transport }) } + /// Create a worker from an already-constructed raw transport. + pub fn from_transport(transport: NostrServerTransport) -> Self { + Self { transport } + } + /// Access the wrapped transport. pub fn transport(&self) -> &NostrServerTransport { &self.transport @@ -157,6 +162,11 @@ impl NostrClientWorker { Ok(Self { transport }) } + /// Create a worker from an already-constructed raw transport. + pub fn from_transport(transport: NostrClientTransport) -> Self { + Self { transport } + } + /// Access the wrapped transport. pub fn transport(&self) -> &NostrClientTransport { &self.transport From 6469c920096765f3978229ba06e66bc808539aa6 Mon Sep 17 00:00:00 2001 From: Harsh Date: Wed, 6 May 2026 21:58:18 +0530 Subject: [PATCH 64/69] fix: pre-release blockers - non_exhaustive, Lagged handling, zero-capacity clamp, rmcp feature gate --- Cargo.toml | 4 + README.md | 36 +- examples/gateway.rs | 24 +- examples/proxy.rs | 16 +- examples/rmcp_integration_test.rs | 36 +- src/core/types.rs | 29 + src/gateway/mod.rs | 8 + src/proxy/mod.rs | 8 + src/transport/client/correlation_store.rs | 2 +- src/transport/client/mod.rs | 49 +- src/transport/server/correlation_store.rs | 2 +- src/transport/server/mod.rs | 75 ++- src/transport/server/session_store.rs | 2 +- tests/conformance_stateless_mode.rs | 16 +- tests/transport_integration.rs | 638 ++++++++-------------- 15 files changed, 444 insertions(+), 501 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ed83c7..fe6c0ee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,10 @@ tokio-util = { version = "0.7", features = ["rt"] } default = ["rmcp"] rmcp = ["dep:rmcp"] +[[example]] +name = "rmcp_integration_test" +required-features = ["rmcp"] + [dev-dependencies] tokio-test = "0.4" anyhow = "1" diff --git a/README.md b/README.md index dfe1d66..ffdb4e5 100644 --- a/README.md +++ b/README.md @@ -77,19 +77,16 @@ use contextvm_sdk::signer; async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); - let config = GatewayConfig { - nostr_config: NostrServerTransportConfig { - relay_urls: vec!["wss://relay.damus.io".into()], - encryption_mode: EncryptionMode::Optional, - server_info: Some(ServerInfo { - name: Some("My MCP Server".into()), - about: Some("Tools via Nostr".into()), - ..Default::default() - }), - is_announced_server: true, - ..Default::default() - }, - }; + let config = GatewayConfig::new( + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_server_info( + ServerInfo::default() + .with_name("My MCP Server") + .with_about("Tools via Nostr"), + ) + .with_announced_server(true), + ); let mut gateway = NostrMCPGateway::new(keys, config).await?; let mut requests = gateway.start().await?; @@ -116,14 +113,11 @@ use contextvm_sdk::signer; async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); - let config = ProxyConfig { - nostr_config: NostrClientTransportConfig { - relay_urls: vec!["wss://relay.damus.io".into()], - server_pubkey: "abc123...server_hex_pubkey".into(), - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, - }; + let config = ProxyConfig::new( + NostrClientTransportConfig::default() + .with_server_pubkey("abc123...server_hex_pubkey") + .with_encryption_mode(EncryptionMode::Optional), + ); let mut proxy = NostrMCPProxy::new(keys, config).await?; let mut responses = proxy.start().await?; diff --git a/examples/gateway.rs b/examples/gateway.rs index 41543d1..3a3bef9 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -37,19 +37,17 @@ async fn main() -> contextvm_sdk::Result<()> { println!("Server pubkey: {}", keys.public_key().to_hex()); // Configure the gateway - let config = GatewayConfig { - nostr_config: NostrServerTransportConfig { - relay_urls: vec!["wss://relay.damus.io".to_string()], - server_info: Some(ServerInfo { - name: Some("Echo Server".to_string()), - about: Some("A simple echo tool exposed via ContextVM".to_string()), - ..Default::default() - }), - is_announced_server: true, - log_file_path, - ..Default::default() - }, - }; + let mut nostr_config = NostrServerTransportConfig::default() + .with_server_info( + ServerInfo::default() + .with_name("Echo Server") + .with_about("A simple echo tool exposed via ContextVM"), + ) + .with_announced_server(true); + if let Some(path) = log_file_path { + nostr_config = nostr_config.with_log_file_path(path); + } + let config = GatewayConfig::new(nostr_config); let mut gateway = NostrMCPGateway::new(keys, config).await?; let mut rx = gateway.start().await?; diff --git a/examples/proxy.rs b/examples/proxy.rs index da5a63e..9f9e1c0 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -41,15 +41,13 @@ async fn main() -> contextvm_sdk::Result<()> { let keys = signer::generate(); println!("Client pubkey: {}", keys.public_key().to_hex()); - let config = ProxyConfig { - nostr_config: NostrClientTransportConfig { - relay_urls: vec!["wss://relay.damus.io".to_string()], - server_pubkey: server_pubkey_hex, - encryption_mode: EncryptionMode::Optional, - log_file_path, - ..Default::default() - }, - }; + let mut nostr_config = NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey_hex) + .with_encryption_mode(EncryptionMode::Optional); + if let Some(path) = log_file_path { + nostr_config = nostr_config.with_log_file_path(path); + } + let config = ProxyConfig::new(nostr_config); let mut proxy = NostrMCPProxy::new(keys, config).await?; let mut rx = proxy.start().await?; diff --git a/examples/rmcp_integration_test.rs b/examples/rmcp_integration_test.rs index c0bedc6..9d82519 100644 --- a/examples/rmcp_integration_test.rs +++ b/examples/rmcp_integration_test.rs @@ -571,30 +571,24 @@ async fn run_relay_rmcp_case(relay_url: &str) -> Result<()> { } fn server_config(relay_url: &str) -> GatewayConfig { - GatewayConfig { - nostr_config: NostrServerTransportConfig { - relay_urls: vec![relay_url.to_string()], - encryption_mode: EncryptionMode::Optional, - server_info: Some(CtxServerInfo { - name: Some("rmcp-matrix-server".to_string()), - about: Some("rmcp matrix coverage server".to_string()), - ..Default::default() - }), - is_announced_server: false, - ..Default::default() - }, - } + let nostr_config = NostrServerTransportConfig::default() + .with_relay_urls(vec![relay_url.to_string()]) + .with_encryption_mode(EncryptionMode::Optional) + .with_server_info( + CtxServerInfo::default() + .with_name("rmcp-matrix-server") + .with_about("rmcp matrix coverage server"), + ) + .with_announced_server(false); + GatewayConfig::new(nostr_config) } fn client_config(relay_url: &str, server_pubkey: String) -> ProxyConfig { - ProxyConfig { - nostr_config: NostrClientTransportConfig { - relay_urls: vec![relay_url.to_string()], - server_pubkey, - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, - } + let nostr_config = NostrClientTransportConfig::default() + .with_relay_urls(vec![relay_url.to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Optional); + ProxyConfig::new(nostr_config) } async fn send_legacy_request_and_wait( diff --git a/src/core/types.rs b/src/core/types.rs index d66311e..d3d0064 100644 --- a/src/core/types.rs +++ b/src/core/types.rs @@ -64,6 +64,7 @@ impl GiftWrapMode { /// Published as the content of a replaceable Nostr event so that clients /// can discover the server's identity and metadata. #[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[non_exhaustive] pub struct ServerInfo { /// Human-readable server name. #[serde(skip_serializing_if = "Option::is_none")] @@ -82,6 +83,34 @@ pub struct ServerInfo { pub about: Option, } +impl ServerInfo { + /// Set the server name. + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + /// Set the server version. + pub fn with_version(mut self, version: impl Into) -> Self { + self.version = Some(version.into()); + self + } + /// Set the server picture URL. + pub fn with_picture(mut self, picture: impl Into) -> Self { + self.picture = Some(picture.into()); + self + } + /// Set the server website URL. + pub fn with_website(mut self, website: impl Into) -> Self { + self.website = Some(website.into()); + self + } + /// Set the server description. + pub fn with_about(mut self, about: impl Into) -> Self { + self.about = Some(about.into()); + self + } +} + // ── Client session ────────────────────────────────────────────────── /// Client session state tracked by the server transport. diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index b5ae17d..97a6457 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -8,11 +8,19 @@ use crate::core::types::JsonRpcMessage; use crate::transport::server::{IncomingRequest, NostrServerTransport, NostrServerTransportConfig}; /// Configuration for the gateway. +#[non_exhaustive] pub struct GatewayConfig { /// Nostr server transport configuration. pub nostr_config: NostrServerTransportConfig, } +impl GatewayConfig { + /// Create a new gateway configuration. + pub fn new(nostr_config: NostrServerTransportConfig) -> Self { + Self { nostr_config } + } +} + /// Gateway that bridges a local MCP server to Nostr. /// /// The gateway listens for incoming MCP requests via Nostr, forwards them diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 1322fa0..e26f2f9 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -8,11 +8,19 @@ use crate::core::types::JsonRpcMessage; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; /// Configuration for the proxy. +#[non_exhaustive] pub struct ProxyConfig { /// Nostr client transport configuration. pub nostr_config: NostrClientTransportConfig, } +impl ProxyConfig { + /// Create a new proxy configuration. + pub fn new(nostr_config: NostrClientTransportConfig) -> Self { + Self { nostr_config } + } +} + /// Proxy that connects to a remote MCP server via Nostr. pub struct NostrMCPProxy { transport: NostrClientTransport, diff --git a/src/transport/client/correlation_store.rs b/src/transport/client/correlation_store.rs index 537c1ae..0fbcd9f 100644 --- a/src/transport/client/correlation_store.rs +++ b/src/transport/client/correlation_store.rs @@ -45,7 +45,7 @@ impl ClientCorrelationStore { pub fn with_max_pending(max_pending: usize) -> Self { Self { pending_requests: Arc::new(RwLock::new(LruCache::new( - NonZeroUsize::new(max_pending).expect("max_pending must be non-zero"), + NonZeroUsize::new(max_pending).unwrap_or(NonZeroUsize::new(1).unwrap()), ))), } } diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 53f494b..4d7ff89 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -31,6 +31,7 @@ use crate::util::tracing_setup; const LOG_TARGET: &str = "contextvm_sdk::transport::client"; /// Configuration for the client transport. +#[non_exhaustive] pub struct NostrClientTransportConfig { /// Relay URLs to connect to. pub relay_urls: Vec, @@ -66,6 +67,44 @@ impl Default for NostrClientTransportConfig { } } +impl NostrClientTransportConfig { + /// Set the server's public key (hex). + pub fn with_server_pubkey(mut self, pubkey: impl Into) -> Self { + self.server_pubkey = pubkey.into(); + self + } + /// Set the encryption mode. + pub fn with_encryption_mode(mut self, mode: EncryptionMode) -> Self { + self.encryption_mode = mode; + self + } + /// Set the gift-wrap mode (CEP-19). + pub fn with_gift_wrap_mode(mut self, mode: GiftWrapMode) -> Self { + self.gift_wrap_mode = mode; + self + } + /// Enable or disable stateless mode. + pub fn with_stateless(mut self, stateless: bool) -> Self { + self.is_stateless = stateless; + self + } + /// Set the relay URLs to connect to. + pub fn with_relay_urls(mut self, urls: Vec) -> Self { + self.relay_urls = urls; + self + } + /// Set the correlation-retention TTL. + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + /// Set the log file path. + pub fn with_log_file_path(mut self, path: impl Into) -> Self { + self.log_file_path = Some(path.into()); + self + } +} + /// Client-side Nostr transport for sending MCP requests and receiving responses. pub struct NostrClientTransport { base: BaseTransport, @@ -442,7 +481,15 @@ impl NostrClientTransport { result = notifications.recv() => { let notification = match result { Ok(n) => n, - Err(_) => break, + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + target: LOG_TARGET, + skipped = n, + "Relay broadcast lagged, skipping missed events" + ); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, }; Self::handle_notification( ¬ification, diff --git a/src/transport/server/correlation_store.rs b/src/transport/server/correlation_store.rs index 80353a8..c25404e 100644 --- a/src/transport/server/correlation_store.rs +++ b/src/transport/server/correlation_store.rs @@ -39,7 +39,7 @@ struct Inner { impl Inner { fn new(max_routes: usize) -> Self { let routes = - LruCache::new(NonZeroUsize::new(max_routes).expect("max_routes must be non-zero")); + LruCache::new(NonZeroUsize::new(max_routes).unwrap_or(NonZeroUsize::new(1).unwrap())); Self { routes, progress_token_to_event: HashMap::new(), diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index a222a3e..0082be1 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -34,6 +34,7 @@ use crate::util::tracing_setup; const LOG_TARGET: &str = "contextvm_sdk::transport::server"; /// Configuration for the server transport. +#[non_exhaustive] pub struct NostrServerTransportConfig { /// Relay URLs to connect to. pub relay_urls: Vec, @@ -113,8 +114,72 @@ pub struct NostrServerTransport { task_handles: Vec>, } +impl NostrServerTransportConfig { + /// Set the encryption mode. + pub fn with_encryption_mode(mut self, mode: EncryptionMode) -> Self { + self.encryption_mode = mode; + self + } + /// Set the gift-wrap mode (CEP-19). + pub fn with_gift_wrap_mode(mut self, mode: GiftWrapMode) -> Self { + self.gift_wrap_mode = mode; + self + } + /// Set server information for announcements. + pub fn with_server_info(mut self, info: ServerInfo) -> Self { + self.server_info = Some(info); + self + } + /// Enable or disable public announcement publishing (CEP-6). + pub fn with_announced_server(mut self, announced: bool) -> Self { + self.is_announced_server = announced; + self + } + /// Set the allowed client public keys (hex). Empty = allow all. + pub fn with_allowed_public_keys(mut self, keys: Vec) -> Self { + self.allowed_public_keys = keys; + self + } + /// Set capabilities excluded from pubkey whitelisting. + pub fn with_excluded_capabilities(mut self, caps: Vec) -> Self { + self.excluded_capabilities = caps; + self + } + /// Set the maximum number of concurrent client sessions. + pub fn with_max_sessions(mut self, max: usize) -> Self { + self.max_sessions = max; + self + } + /// Set the relay URLs to connect to. + pub fn with_relay_urls(mut self, urls: Vec) -> Self { + self.relay_urls = urls; + self + } + /// Set the session cleanup interval. + pub fn with_cleanup_interval(mut self, interval: Duration) -> Self { + self.cleanup_interval = interval; + self + } + /// Set the session timeout. + pub fn with_session_timeout(mut self, timeout: Duration) -> Self { + self.session_timeout = timeout; + self + } + /// Set the correlation-retention TTL for event routes. + pub fn with_request_timeout(mut self, timeout: Duration) -> Self { + self.request_timeout = timeout; + self + } + /// Set the log file path. + pub fn with_log_file_path(mut self, path: impl Into) -> Self { + self.log_file_path = Some(path.into()); + self + } +} + /// An incoming MCP request with metadata for routing the response. #[derive(Debug)] +#[non_exhaustive] pub struct IncomingRequest { /// The parsed MCP message. pub message: JsonRpcMessage, @@ -899,7 +964,15 @@ impl NostrServerTransport { result = notifications.recv() => { match result { Ok(n) => n, - Err(_) => break, + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!( + target: LOG_TARGET, + skipped = n, + "Relay broadcast lagged, skipping missed events" + ); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, } } }; diff --git a/src/transport/server/session_store.rs b/src/transport/server/session_store.rs index 1415a90..a37d53a 100644 --- a/src/transport/server/session_store.rs +++ b/src/transport/server/session_store.rs @@ -53,7 +53,7 @@ impl SessionStore { pub fn with_capacity(max_sessions: usize) -> Self { Self { sessions: Arc::new(RwLock::new(LruCache::new( - NonZeroUsize::new(max_sessions).expect("max_sessions must be > 0"), + NonZeroUsize::new(max_sessions).unwrap_or(NonZeroUsize::new(1).unwrap()), ))), on_evicted: None, } diff --git a/tests/conformance_stateless_mode.rs b/tests/conformance_stateless_mode.rs index 9999e23..0e30a65 100644 --- a/tests/conformance_stateless_mode.rs +++ b/tests/conformance_stateless_mode.rs @@ -17,15 +17,13 @@ async fn make_stateless_transport() -> ( let server_keys = signer::generate(); let client_keys = signer::generate(); - let config = NostrClientTransportConfig { - relay_urls: Vec::new(), - server_pubkey: server_keys.public_key().to_hex(), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - is_stateless: true, - timeout: Duration::from_secs(1), - log_file_path: None, - }; + let config = NostrClientTransportConfig::default() + .with_relay_urls(Vec::new()) + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_stateless(true) + .with_timeout(Duration::from_secs(1)); let mut transport = NostrClientTransport::new(client_keys, config) .await diff --git a/tests/transport_integration.rs b/tests/transport_integration.rs index 5bddfad..7bcaf1a 100644 --- a/tests/transport_integration.rs +++ b/tests/transport_integration.rs @@ -136,21 +136,16 @@ async fn full_initialization_handshake() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -227,15 +222,10 @@ async fn server_announcement_publishing() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("Phase3-Test-Server".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("Phase3-Test-Server".to_string())) + .with_announced_server(true), Arc::clone(&pool) as Arc, ) .await @@ -273,10 +263,7 @@ async fn encryption_mode_optional_accepts_plaintext() { // Server uses Optional — should accept both encrypted and plaintext. let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), as_pool(server_pool), ) .await @@ -289,11 +276,9 @@ async fn encryption_mode_optional_accepts_plaintext() { // Client uses Disabled — sends plaintext kind 25910. let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -337,11 +322,9 @@ async fn auth_allowlist_blocks_disallowed_pubkey() { // Server allows only `allowed_keys` — client_keys is NOT allowed. let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![allowed_keys.public_key().to_hex()], - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), as_pool(server_pool), ) .await @@ -353,11 +336,9 @@ async fn auth_allowlist_blocks_disallowed_pubkey() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -392,10 +373,7 @@ async fn encryption_mode_required_drops_plaintext() { // Server requires encryption — plaintext must be dropped. let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), as_pool(server_pool), ) .await @@ -408,11 +386,9 @@ async fn encryption_mode_required_drops_plaintext() { // Client sends plaintext (Disabled mode). let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -446,21 +422,16 @@ async fn encrypted_gift_wrap_roundtrip() { let server_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), Arc::clone(&server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), as_pool(client_pool), ) .await @@ -534,21 +505,16 @@ async fn gift_wrap_dedup_skips_duplicate_delivery() { let server_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Required), Arc::clone(&server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), as_pool(client_pool), ) .await @@ -608,21 +574,16 @@ async fn correlated_notification_has_e_tag() { let server_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -707,21 +668,16 @@ async fn encryption_required_client_optional_server() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), as_pool(server_pool), ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), as_pool(client_pool), ) .await @@ -769,21 +725,16 @@ async fn encryption_optional_both_sides_encrypted_path() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), as_pool(server_pool), ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional), as_pool(client_pool), ) .await @@ -824,15 +775,10 @@ async fn announce_includes_encryption_tags() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("Encrypted-Server".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Required) + .with_server_info(ServerInfo::default().with_name("Encrypted-Server".to_string())) + .with_announced_server(true), Arc::clone(&pool) as Arc, ) .await @@ -873,18 +819,16 @@ async fn announce_includes_server_metadata_tags() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("Meta-Server".to_string()), - about: Some("A test server".to_string()), - website: Some("https://example.com".to_string()), - picture: Some("https://example.com/pic.png".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info( + ServerInfo::default() + .with_name("Meta-Server".to_string()) + .with_about("A test server".to_string()) + .with_website("https://example.com".to_string()) + .with_picture("https://example.com/pic.png".to_string()), + ) + .with_announced_server(true), Arc::clone(&pool) as Arc, ) .await @@ -935,15 +879,10 @@ async fn publish_tools_produces_correct_kind() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("Tools-Server".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("Tools-Server".to_string())) + .with_announced_server(true), Arc::clone(&pool) as Arc, ) .await @@ -984,10 +923,7 @@ async fn broadcast_notification_reaches_initialized_client() { let server_pk = s_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(s_pool), ) .await @@ -999,11 +935,9 @@ async fn broadcast_notification_reaches_initialized_client() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pk.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pk.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(c1_pool), ) .await @@ -1106,21 +1040,16 @@ async fn uncorrelated_notification_passes_through() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -1204,21 +1133,16 @@ async fn correlated_notification_unknown_e_tag_is_dropped() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -1299,11 +1223,9 @@ async fn auth_allowed_pubkey_receives_response() { let client_pubkey = client_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![client_pubkey.to_hex()], - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![client_pubkey.to_hex()]), as_pool(server_pool), ) .await @@ -1315,11 +1237,9 @@ async fn auth_allowed_pubkey_receives_response() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -1378,15 +1298,13 @@ async fn excluded_capability_bypasses_auth() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![allowed_keys.public_key().to_hex()], - excluded_capabilities: vec![CapabilityExclusion { + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]) + .with_excluded_capabilities(vec![CapabilityExclusion { method: "tools/list".to_string(), name: None, - }], - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + }]), as_pool(server_pool), ) .await @@ -1398,11 +1316,9 @@ async fn excluded_capability_bypasses_auth() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -1440,10 +1356,7 @@ async fn publish_resources_produces_correct_kind() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&pool) as Arc, ) .await @@ -1483,10 +1396,7 @@ async fn publish_prompts_produces_correct_kind() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&pool) as Arc, ) .await @@ -1525,10 +1435,7 @@ async fn publish_resource_templates_produces_correct_kind() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&pool) as Arc, ) .await @@ -1568,10 +1475,7 @@ async fn publish_tools_empty_list() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&pool) as Arc, ) .await @@ -1602,15 +1506,10 @@ async fn delete_announcements_k_tags_match_kinds() { let pool = Arc::new(MockRelayPool::new()); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("KTag-Server".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("KTag-Server".to_string())) + .with_announced_server(true), Arc::clone(&pool) as Arc, ) .await @@ -1665,10 +1564,7 @@ async fn encryption_disabled_server_rejects_gift_wrap() { // Server has encryption disabled — must reject gift-wrap events. let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await @@ -1681,11 +1577,9 @@ async fn encryption_disabled_server_rejects_gift_wrap() { // Client requires encryption — sends gift-wrap (kind 1059). let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), as_pool(client_pool), ) .await @@ -1720,21 +1614,16 @@ async fn response_mirrors_client_encryption_format() { let server_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), Arc::clone(&server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -1807,21 +1696,16 @@ async fn response_mirrors_client_encryption_format() { let server_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Optional), Arc::clone(&server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Required, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Required), as_pool(client_pool), ) .await @@ -1908,21 +1792,16 @@ async fn send_response_is_one_shot_under_concurrency() { )); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), delayed_server_pool, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2010,21 +1889,16 @@ async fn send_response_publish_failure_allows_one_successful_retry() { )); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&failing_server_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2132,12 +2006,10 @@ async fn announced_server_sends_unauthorized_error_response() { // Announced server with an allowlist that does NOT include the client. let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![allowed_keys.public_key().to_hex()], - is_announced_server: true, - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_announced_server(true) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), as_pool(server_pool), ) .await @@ -2149,11 +2021,9 @@ async fn announced_server_sends_unauthorized_error_response() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2212,11 +2082,9 @@ async fn private_server_silently_drops_unauthorized_request() { // Private server (is_announced_server defaults to false). let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![allowed_keys.public_key().to_hex()], - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), as_pool(server_pool), ) .await @@ -2228,11 +2096,9 @@ async fn private_server_silently_drops_unauthorized_request() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2276,12 +2142,10 @@ async fn announced_server_does_not_error_on_unauthorized_notification() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - allowed_public_keys: vec![allowed_keys.public_key().to_hex()], - is_announced_server: true, - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_announced_server(true) + .with_allowed_public_keys(vec![allowed_keys.public_key().to_hex()]), as_pool(server_pool), ) .await @@ -2293,11 +2157,9 @@ async fn announced_server_does_not_error_on_unauthorized_notification() { server.start().await.expect("server start"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2341,27 +2203,20 @@ async fn first_response_includes_discovery_tags() { let s_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - is_announced_server: true, - server_info: Some(ServerInfo { - name: Some("Disco-Server".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("Disco-Server".to_string())) + .with_announced_server(true), Arc::clone(&s_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2472,23 +2327,19 @@ async fn notification_mirror_selection_wrt_cep_19() { let s_pool = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), Arc::clone(&s_pool) as Arc, ) .await .expect("create server transport"); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Ephemeral, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Ephemeral), as_pool(client_pool), ) .await @@ -2561,22 +2412,18 @@ async fn server_response_includes_encryption_tags_when_enabled() { let server_pool_arc = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), Arc::clone(&server_pool_arc) as Arc, ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2639,22 +2486,18 @@ async fn server_response_excludes_ephemeral_tag_when_persistent() { let server_pool_arc = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Persistent, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Persistent), Arc::clone(&server_pool_arc) as Arc, ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2716,21 +2559,16 @@ async fn server_learns_capabilities_from_client_request() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2781,25 +2619,18 @@ async fn server_disabled_encryption_omits_encryption_tags() { let server_pool_arc = Arc::new(server_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - server_info: Some(ServerInfo { - name: Some("NoEncrypt".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Disabled) + .with_server_info(ServerInfo::default().with_name("NoEncrypt".to_string())), Arc::clone(&server_pool_arc) as Arc, ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -2866,12 +2697,10 @@ async fn client_disabled_encryption_emits_no_discovery_tags() { let server_keys = Keys::generate(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_keys.public_key().to_hex(), - encryption_mode: EncryptionMode::Disabled, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), Arc::clone(&pool) as Arc, ) .await @@ -2915,12 +2744,10 @@ async fn client_second_request_carries_no_discovery_tags() { let server_keys = Keys::generate(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_keys.public_key().to_hex(), - encryption_mode: EncryptionMode::Disabled, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_keys.public_key().to_hex()) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), Arc::clone(&pool) as Arc, ) .await @@ -2975,26 +2802,19 @@ async fn client_learns_server_capabilities_from_first_response() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - server_info: Some(ServerInfo { - name: Some("CapServer".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("CapServer".to_string())), as_pool(server_pool), ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -3067,26 +2887,19 @@ async fn client_or_assigns_capabilities_across_responses() { let client_pool = Arc::new(client_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - server_info: Some(ServerInfo { - name: Some("PersistentServer".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Persistent, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Persistent) + .with_server_info(ServerInfo::default().with_name("PersistentServer".to_string())), as_pool(server_pool), ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&client_pool) as Arc, ) .await @@ -3186,26 +2999,19 @@ async fn client_baseline_event_not_replaced_by_later_responses() { let client_pool = Arc::new(client_pool); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - server_info: Some(ServerInfo { - name: Some("BaselineServer".to_string()), - ..Default::default() - }), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_server_info(ServerInfo::default().with_name("BaselineServer".to_string())), as_pool(server_pool), ) .await .unwrap(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), Arc::clone(&client_pool) as Arc, ) .await @@ -3296,12 +3102,10 @@ async fn client_optional_encryption_emits_discovery_tags() { let client_pool = Arc::new(client_pool); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), Arc::clone(&client_pool) as Arc, ) .await @@ -3361,32 +3165,25 @@ async fn multi_client_concurrent_requests_both_get_responses() { let server_pubkey = server_pool.mock_public_key(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await .expect("create server transport"); let mut client_a = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_a_pool), ) .await .expect("create client A"); let mut client_b = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_b_pool), ) .await @@ -3532,11 +3329,9 @@ async fn client_close_stops_event_loop() { let server_pubkey = server_pool.mock_public_key(); let mut client = NostrClientTransport::with_relay_pool( - NostrClientTransportConfig { - server_pubkey: server_pubkey.to_hex(), - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_server_pubkey(server_pubkey.to_hex()) + .with_encryption_mode(EncryptionMode::Disabled), as_pool(client_pool), ) .await @@ -3562,10 +3357,7 @@ async fn server_close_stops_event_loop() { let (_client_pool, server_pool) = MockRelayPool::create_pair(); let mut server = NostrServerTransport::with_relay_pool( - NostrServerTransportConfig { - encryption_mode: EncryptionMode::Disabled, - ..Default::default() - }, + NostrServerTransportConfig::default().with_encryption_mode(EncryptionMode::Disabled), as_pool(server_pool), ) .await From ad372290c0a4555077efc3673a83a5c5121e6c7a Mon Sep 17 00:00:00 2001 From: ContextVM Date: Wed, 6 May 2026 20:18:18 +0200 Subject: [PATCH 65/69] refactor(rmcp): simplify transport API by removing wrapper adapters Remove NostrServerRmcpTransport and NostrClientRmcpTransport wrapper types, implementing IntoTransport directly on the raw Nostr transports instead. This simplifies the API by eliminating the redundant .into_rmcp_transport() call and reducing indirection in the transport layer. --- examples/native_echo_client.rs | 2 +- examples/native_echo_server.rs | 4 +--- sdk | 2 +- src/gateway/mod.rs | 6 ++--- src/proxy/mod.rs | 6 ++--- src/rmcp_transport/mod.rs | 6 ++--- src/rmcp_transport/transport.rs | 42 +++++++-------------------------- 7 files changed, 19 insertions(+), 49 deletions(-) diff --git a/examples/native_echo_client.rs b/examples/native_echo_client.rs index 7c46f87..734a7e4 100644 --- a/examples/native_echo_client.rs +++ b/examples/native_echo_client.rs @@ -50,7 +50,7 @@ async fn main() -> Result<()> { ) .await?; - let client = EchoClient.serve(transport.into_rmcp_transport()).await?; + let client = EchoClient.serve(transport).await?; let peer_info = client .peer_info() diff --git a/examples/native_echo_server.rs b/examples/native_echo_server.rs index dc3b302..e0548d4 100644 --- a/examples/native_echo_server.rs +++ b/examples/native_echo_server.rs @@ -98,9 +98,7 @@ async fn main() -> Result<()> { ) .await?; - let service = EchoServer::new() - .serve(transport.into_rmcp_transport()) - .await?; + let service = EchoServer::new().serve(transport).await?; println!("Server ready. Press Ctrl+C to stop."); service.waiting().await?; Ok(()) diff --git a/sdk b/sdk index 5f773a2..7a0c5c3 160000 --- a/sdk +++ b/sdk @@ -1 +1 @@ -Subproject commit 5f773a20d9ea4b0a5f06d1b860d3d3da7509699f +Subproject commit 7a0c5c398c8ffcc6b07ddb181c67a29b77dc3cb4 diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 0feb03a..c5cd794 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -84,7 +84,7 @@ impl NostrMCPGateway { /// Start a gateway directly from an rmcp server handler. /// /// This additive API keeps the existing `new/start/send_response` flow intact, - /// while allowing rmcp-first usage through the worker adapter. + /// while also allowing direct `handler.serve(transport)` style usage. pub async fn serve_handler( signer: T, config: GatewayConfig, @@ -97,9 +97,7 @@ impl NostrMCPGateway { use crate::NostrServerTransport; use rmcp::ServiceExt; - let transport = NostrServerTransport::new(signer, config.nostr_config) - .await? - .into_rmcp_transport(); + let transport = NostrServerTransport::new(signer, config.nostr_config).await?; handler .serve(transport) .await diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index c6c1a34..3c9f9e7 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -73,7 +73,7 @@ impl NostrMCPProxy { /// Start a proxy directly from an rmcp client handler. /// /// This additive API keeps the existing `new/start/send` flow intact, - /// while allowing rmcp-first usage through the worker adapter. + /// while also allowing direct `handler.serve(transport)` style usage. pub async fn serve_client_handler( signer: T, config: ProxyConfig, @@ -86,9 +86,7 @@ impl NostrMCPProxy { use crate::NostrClientTransport; use rmcp::ServiceExt; - let transport = NostrClientTransport::new(signer, config.nostr_config) - .await? - .into_rmcp_transport(); + let transport = NostrClientTransport::new(signer, config.nostr_config).await?; handler .serve(transport) .await diff --git a/src/rmcp_transport/mod.rs b/src/rmcp_transport/mod.rs index bf5c99f..436f1dd 100644 --- a/src/rmcp_transport/mod.rs +++ b/src/rmcp_transport/mod.rs @@ -1,6 +1,7 @@ -//! RMCP integration scaffolding. +//! rmcp integration for ContextVM Nostr transports. //! -//! This module bridges the existing Nostr transport implementation with rmcp services. +//! This module contains the conversion helpers and worker bridge that let raw +//! ContextVM transports plug directly into rmcp service APIs. pub mod convert; pub mod transport; @@ -13,5 +14,4 @@ pub use convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, }; -pub use transport::{NostrClientRmcpTransport, NostrServerRmcpTransport}; pub use worker::{NostrClientWorker, NostrServerWorker}; diff --git a/src/rmcp_transport/transport.rs b/src/rmcp_transport/transport.rs index f10223a..29e8fb4 100644 --- a/src/rmcp_transport/transport.rs +++ b/src/rmcp_transport/transport.rs @@ -1,4 +1,4 @@ -//! Direct rmcp adapter entrypoints over raw ContextVM Nostr transports. +//! rmcp transport integration for raw ContextVM Nostr transports. use crate::{ core::error::Error, @@ -6,50 +6,26 @@ use crate::{ transport::{client::NostrClientTransport, server::NostrServerTransport}, }; -/// Direct rmcp adapter for [`NostrServerTransport`](src/transport/server/mod.rs:87). -pub struct NostrServerRmcpTransport { - worker: NostrServerWorker, -} - -impl NostrServerTransport { - /// Convert this raw transport into an rmcp-compatible transport adapter. - pub fn into_rmcp_transport(self) -> NostrServerRmcpTransport { - NostrServerRmcpTransport { - worker: NostrServerWorker::from_transport(self), - } - } -} - impl rmcp::transport::IntoTransport - for NostrServerRmcpTransport + for NostrServerTransport { + /// Convert the raw server transport into rmcp's transport model via the + /// worker bridge. fn into_transport( self, ) -> impl rmcp::transport::Transport + 'static { - self.worker.into_transport() - } -} - -/// Direct rmcp adapter for [`NostrClientTransport`](src/transport/client/mod.rs:69). -pub struct NostrClientRmcpTransport { - worker: NostrClientWorker, -} - -impl NostrClientTransport { - /// Convert this raw transport into an rmcp-compatible transport adapter. - pub fn into_rmcp_transport(self) -> NostrClientRmcpTransport { - NostrClientRmcpTransport { - worker: NostrClientWorker::from_transport(self), - } + NostrServerWorker::from_transport(self).into_transport() } } impl rmcp::transport::IntoTransport - for NostrClientRmcpTransport + for NostrClientTransport { + /// Convert the raw client transport into rmcp's transport model via the + /// worker bridge. fn into_transport( self, ) -> impl rmcp::transport::Transport + 'static { - self.worker.into_transport() + NostrClientWorker::from_transport(self).into_transport() } } From 53c2a9e6c3197fe9990d6a601cef1c718a764799 Mon Sep 17 00:00:00 2001 From: ContextVM Date: Thu, 7 May 2026 12:21:03 +0200 Subject: [PATCH 66/69] refactor(examples): use builder pattern for transport configuration Refactored native_echo_client and native_echo_server examples to use the builder pattern (with_*) instead of struct literal syntax for NostrClientTransportConfig and NostrServerTransportConfig initialization. This provides a more fluent and consistent API for configuring transport options. --- examples/native_echo_client.rs | 12 +++++------- examples/native_echo_server.rs | 22 ++++++++++------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/native_echo_client.rs b/examples/native_echo_client.rs index 734a7e4..be5e0dd 100644 --- a/examples/native_echo_client.rs +++ b/examples/native_echo_client.rs @@ -40,13 +40,11 @@ async fn main() -> Result<()> { let transport = NostrClientTransport::new( signer, - NostrClientTransportConfig { - relay_urls: vec![RELAY_URL.to_string()], - server_pubkey, - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - ..Default::default() - }, + NostrClientTransportConfig::default() + .with_relay_urls(vec![RELAY_URL.to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional), ) .await?; diff --git a/examples/native_echo_server.rs b/examples/native_echo_server.rs index e0548d4..e463e62 100644 --- a/examples/native_echo_server.rs +++ b/examples/native_echo_server.rs @@ -83,18 +83,16 @@ async fn main() -> Result<()> { let transport = NostrServerTransport::new( signer, - NostrServerTransportConfig { - relay_urls: vec![RELAY_URL.to_string()], - encryption_mode: EncryptionMode::Optional, - gift_wrap_mode: GiftWrapMode::Optional, - is_announced_server: false, - server_info: Some(ServerInfo { - name: Some("contextvm-native-echo".to_string()), - about: Some("Native rmcp echo server example".to_string()), - ..Default::default() - }), - ..Default::default() - }, + NostrServerTransportConfig::default() + .with_relay_urls(vec![RELAY_URL.to_string()]) + .with_encryption_mode(EncryptionMode::Optional) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_announced_server(false) + .with_server_info( + ServerInfo::default() + .with_name("contextvm-native-echo".to_string()) + .with_about("Native rmcp echo server example".to_string()), + ), ) .await?; From 938a5bec9f762b340b177a55b5cbd50cf158ca07 Mon Sep 17 00:00:00 2001 From: Harsh Date: Thu, 7 May 2026 21:33:24 +0530 Subject: [PATCH 67/69] chore: add crates.io publishing metadata, CHANGELOG, remove tracing-subscriber from library --- CHANGELOG.md | 32 +++++ Cargo.toml | 18 ++- examples/gateway.rs | 27 +--- examples/proxy.rs | 40 +----- src/gateway/mod.rs | 1 - src/lib.rs | 2 - src/proxy/mod.rs | 1 - src/transport/client/mod.rs | 15 --- src/transport/server/mod.rs | 15 --- src/util/mod.rs | 1 - src/util/tracing_setup.rs | 259 ------------------------------------ 11 files changed, 58 insertions(+), 353 deletions(-) create mode 100644 CHANGELOG.md delete mode 100644 src/util/mod.rs delete mode 100644 src/util/tracing_setup.rs diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..694916f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,32 @@ +# Changelog + +## [0.1.0] - 2026-05-07 + +### Added + +- Core transport layer: `NostrClientTransport` and `NostrServerTransport` over NIP-59 gift wraps +- Gateway and Proxy high-level APIs for bridging MCP over Nostr +- Discovery API: `discover_servers`, `discover_tools`, `discover_resources`, `discover_prompts`, `discover_resource_templates` +- CEP-6: server announcement publishing and querying (kinds 11316–11320) +- CEP-19: ephemeral gift wraps (kind 21059) with `GiftWrapMode` negotiation on both client and server +- CEP-35: stateless session discovery, tag composition, and capability learning +- LRU-bounded session store with configurable capacity (default 1000 sessions) and TTL expiry +- Multi-client support in `NostrServerWorker` (removed single-peer barrier) +- Direct rmcp transport adapters via `into_rmcp_transport()` for native `ContextVM` services +- `CancellationToken`-based graceful shutdown on `close()` +- TTL sweep for client and server correlation stores to prevent pending-request leaks +- `MockRelayPool` for deterministic offline testing +- Builder pattern for all transport and worker configuration structs +- Four examples: gateway, proxy, discovery, and rmcp integration test + +### Fixed + +- Single-peer barrier in RMCP worker rejected concurrent clients (#60) +- Pending-request leak: correlation store entries never expired by TTL (#61) +- Event loop tasks not cancelled on `close()`, causing resource leaks (#63) +- `RecvError::Lagged` killing event loop under high relay throughput (#68) +- Client race condition: responses lost when publish completed before correlation registration (#55) +- Uncorrelated responses (missing `e` tag) forwarded to consumer instead of dropped (#55) +- Non-atomic `send_response` behavior in server transport (#48) +- Unbounded LRU cache initialization with zero capacity (#50) +- Announced servers not sending JSON-RPC `-32000 Unauthorized` error for disallowed clients (#53) diff --git a/Cargo.toml b/Cargo.toml index fe6c0ee..7abf096 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,9 +2,15 @@ name = "contextvm-sdk" version = "0.1.0" edition = "2021" +rust-version = "1.70" description = "Rust SDK for the ContextVM protocol — MCP over Nostr" license = "MIT" -repository = "https://github.com/k0sti/rust-contextvm-sdk" +readme = "README.md" +repository = "https://github.com/ContextVM/rs-sdk" +homepage = "https://contextvm.org" +documentation = "https://docs.rs/contextvm-sdk" +keywords = ["nostr", "mcp", "model-context-protocol", "decentralized", "ai"] +categories = ["network-programming", "api-bindings", "asynchronous"] [dependencies] # Async runtime @@ -23,7 +29,6 @@ nostr-sdk = { version = "0.43", features = ["nip59"] } # Logging tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Optional MCP integration (Rust equivalent to TS @modelcontextprotocol/sdk) rmcp = { version = "0.16.0", features = ["server", "client", "macros", "transport-worker"], optional = true } @@ -43,7 +48,16 @@ rmcp = ["dep:rmcp"] name = "rmcp_integration_test" required-features = ["rmcp"] +[[example]] +name = "native_echo_server" +required-features = ["rmcp"] + +[[example]] +name = "native_echo_client" +required-features = ["rmcp"] + [dev-dependencies] tokio-test = "0.4" anyhow = "1" schemars = "0.8" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/gateway.rs b/examples/gateway.rs index 3a3bef9..efe8500 100644 --- a/examples/gateway.rs +++ b/examples/gateway.rs @@ -3,7 +3,7 @@ //! This demonstrates how to create a ContextVM gateway that receives //! MCP requests over Nostr and responds to them. //! -//! Usage: cargo run --example gateway -- [--log-file ] +//! Usage: cargo run --example gateway use contextvm_sdk::core::types::*; use contextvm_sdk::gateway::{GatewayConfig, NostrMCPGateway}; @@ -12,41 +12,20 @@ use contextvm_sdk::transport::server::NostrServerTransportConfig; #[tokio::main] async fn main() -> contextvm_sdk::Result<()> { - let args: Vec = std::env::args().skip(1).collect(); - let mut log_file_path: Option = None; - - let mut index = 0; - while index < args.len() { - match args[index].as_str() { - "--log-file" => { - index += 1; - let Some(path) = args.get(index) else { - panic!("Usage: gateway [--log-file ]"); - }; - log_file_path = Some(path.clone()); - } - other => { - panic!("Unknown argument: {other}. Usage: gateway [--log-file ]"); - } - } - index += 1; - } + tracing_subscriber::fmt::init(); // Generate ephemeral keys for this session let keys = signer::generate(); println!("Server pubkey: {}", keys.public_key().to_hex()); // Configure the gateway - let mut nostr_config = NostrServerTransportConfig::default() + let nostr_config = NostrServerTransportConfig::default() .with_server_info( ServerInfo::default() .with_name("Echo Server") .with_about("A simple echo tool exposed via ContextVM"), ) .with_announced_server(true); - if let Some(path) = log_file_path { - nostr_config = nostr_config.with_log_file_path(path); - } let config = GatewayConfig::new(nostr_config); let mut gateway = NostrMCPGateway::new(keys, config).await?; diff --git a/examples/proxy.rs b/examples/proxy.rs index 9f9e1c0..4aea3e3 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -1,52 +1,26 @@ //! Example: Connect to a remote MCP server via Nostr and call tools/list. //! -//! Usage: cargo run --example proxy -- [--log-file ] +//! Usage: cargo run --example proxy -- use contextvm_sdk::core::types::*; use contextvm_sdk::proxy::{NostrMCPProxy, ProxyConfig}; use contextvm_sdk::signer; use contextvm_sdk::transport::client::NostrClientTransportConfig; + #[tokio::main] async fn main() -> contextvm_sdk::Result<()> { - let args: Vec = std::env::args().skip(1).collect(); - let mut server_pubkey_hex: Option = None; - let mut log_file_path: Option = None; - - let mut index = 0; - while index < args.len() { - match args[index].as_str() { - "--log-file" => { - index += 1; - let Some(path) = args.get(index) else { - panic!("Usage: proxy [--log-file ]"); - }; - log_file_path = Some(path.clone()); - } - value => { - if server_pubkey_hex.is_none() { - server_pubkey_hex = Some(value.to_string()); - } else { - panic!( - "Unknown argument: {value}. Usage: proxy [--log-file ]" - ); - } - } - } - index += 1; - } + tracing_subscriber::fmt::init(); - let server_pubkey_hex = - server_pubkey_hex.expect("Usage: proxy [--log-file ]"); + let server_pubkey_hex = std::env::args() + .nth(1) + .expect("Usage: proxy "); let keys = signer::generate(); println!("Client pubkey: {}", keys.public_key().to_hex()); - let mut nostr_config = NostrClientTransportConfig::default() + let nostr_config = NostrClientTransportConfig::default() .with_server_pubkey(server_pubkey_hex) .with_encryption_mode(EncryptionMode::Optional); - if let Some(path) = log_file_path { - nostr_config = nostr_config.with_log_file_path(path); - } let config = ProxyConfig::new(nostr_config); let mut proxy = NostrMCPProxy::new(keys, config).await?; diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 825fa46..e4bba91 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -138,7 +138,6 @@ mod tests { cleanup_interval: Duration::from_secs(120), session_timeout: Duration::from_secs(600), request_timeout: Duration::from_secs(60), - log_file_path: None, }; let config = GatewayConfig { nostr_config }; diff --git a/src/lib.rs b/src/lib.rs index 575e356..7615c7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,8 +47,6 @@ pub mod transport; #[cfg(feature = "rmcp")] pub mod rmcp_transport; -mod util; - // Re-export commonly used types pub use core::error::{Error, Result}; pub use core::types::{ diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 9aa97ae..4833127 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -121,7 +121,6 @@ mod tests { gift_wrap_mode: GiftWrapMode::Optional, is_stateless: true, timeout: Duration::from_secs(60), - log_file_path: None, }; let config = ProxyConfig { nostr_config }; diff --git a/src/transport/client/mod.rs b/src/transport/client/mod.rs index 4d7ff89..31f80c7 100644 --- a/src/transport/client/mod.rs +++ b/src/transport/client/mod.rs @@ -26,8 +26,6 @@ use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; use crate::transport::discovery_tags::{parse_discovered_peer_capabilities, PeerCapabilities}; -use crate::util::tracing_setup; - const LOG_TARGET: &str = "contextvm_sdk::transport::client"; /// Configuration for the client transport. @@ -49,8 +47,6 @@ pub struct NostrClientTransportConfig { /// This prevents leaks -- rmcp owns actual request timeout and cancellation. /// Keep this value above your rmcp request timeout to avoid premature cleanup. pub timeout: Duration, - /// Optional log file path. Logs always go to stdout and are also appended here when set. - pub log_file_path: Option, } impl Default for NostrClientTransportConfig { @@ -62,7 +58,6 @@ impl Default for NostrClientTransportConfig { gift_wrap_mode: GiftWrapMode::Optional, is_stateless: false, timeout: Duration::from_secs(30), - log_file_path: None, } } } @@ -98,11 +93,6 @@ impl NostrClientTransportConfig { self.timeout = timeout; self } - /// Set the log file path. - pub fn with_log_file_path(mut self, path: impl Into) -> Self { - self.log_file_path = Some(path.into()); - self - } } /// Client-side Nostr transport for sending MCP requests and receiving responses. @@ -139,8 +129,6 @@ impl NostrClientTransport { where T: IntoNostrSigner, { - tracing_setup::init_tracer(config.log_file_path.as_deref())?; - let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { tracing::error!( target: LOG_TARGET, @@ -198,8 +186,6 @@ impl NostrClientTransport { config: NostrClientTransportConfig, relay_pool: Arc, ) -> Result { - tracing_setup::init_tracer(config.log_file_path.as_deref())?; - let server_pubkey = PublicKey::from_hex(&config.server_pubkey).map_err(|error| { tracing::error!( target: LOG_TARGET, @@ -870,7 +856,6 @@ mod tests { assert_eq!(config.gift_wrap_mode, GiftWrapMode::Optional); assert!(!config.is_stateless); assert_eq!(config.timeout, Duration::from_secs(30)); - assert!(config.log_file_path.is_none()); } #[test] diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index 0082be1..ba9f6c7 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -29,8 +29,6 @@ use crate::relay::{RelayPool, RelayPoolTrait}; use crate::transport::base::BaseTransport; use crate::transport::discovery_tags::learn_peer_capabilities; -use crate::util::tracing_setup; - const LOG_TARGET: &str = "contextvm_sdk::transport::server"; /// Configuration for the server transport. @@ -62,8 +60,6 @@ pub struct NostrServerTransportConfig { /// This prevents leaks -- rmcp owns actual request timeout and cancellation. /// Keep this value above your rmcp request timeout to avoid premature cleanup. pub request_timeout: Duration, - /// Optional log file path. Logs always go to stdout and are also appended here when set. - pub log_file_path: Option, } impl Default for NostrServerTransportConfig { @@ -80,7 +76,6 @@ impl Default for NostrServerTransportConfig { cleanup_interval: Duration::from_secs(60), session_timeout: Duration::from_secs(300), request_timeout: Duration::from_secs(60), - log_file_path: None, } } } @@ -170,11 +165,6 @@ impl NostrServerTransportConfig { self.request_timeout = timeout; self } - /// Set the log file path. - pub fn with_log_file_path(mut self, path: impl Into) -> Self { - self.log_file_path = Some(path.into()); - self - } } /// An incoming MCP request with metadata for routing the response. @@ -197,8 +187,6 @@ impl NostrServerTransport { where T: IntoNostrSigner, { - tracing_setup::init_tracer(config.log_file_path.as_deref())?; - let relay_pool: Arc = Arc::new(RelayPool::new(signer).await.map_err(|error| { tracing::error!( @@ -246,8 +234,6 @@ impl NostrServerTransport { config: NostrServerTransportConfig, relay_pool: Arc, ) -> Result { - tracing_setup::init_tracer(config.log_file_path.as_deref())?; - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); let seen_gift_wrap_ids = Arc::new(Mutex::new(LruCache::new( NonZeroUsize::new(DEFAULT_LRU_SIZE).expect("DEFAULT_LRU_SIZE must be non-zero"), @@ -1672,7 +1658,6 @@ mod tests { assert_eq!(config.session_timeout, Duration::from_secs(300)); assert_eq!(config.request_timeout, Duration::from_secs(60)); assert!(config.server_info.is_none()); - assert!(config.log_file_path.is_none()); } // ── CEP-19 helper logic ────────────────────────────────────── diff --git a/src/util/mod.rs b/src/util/mod.rs deleted file mode 100644 index c5eb5d2..0000000 --- a/src/util/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod tracing_setup; diff --git a/src/util/tracing_setup.rs b/src/util/tracing_setup.rs deleted file mode 100644 index 313bb5d..0000000 --- a/src/util/tracing_setup.rs +++ /dev/null @@ -1,259 +0,0 @@ -//! Internal tracing subscriber setup for ContextVM transports. - -use std::fmt; -use std::fs::{File, OpenOptions}; -use std::io::{self, Write}; -use std::path::{Path, PathBuf}; -use std::sync::{Mutex, OnceLock}; - -use tracing::Event; -use tracing_subscriber::fmt::format::Writer; -use tracing_subscriber::fmt::writer::MakeWriter; -use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields}; -use tracing_subscriber::layer::{Layer, SubscriberExt}; -use tracing_subscriber::registry::LookupSpan; -use tracing_subscriber::{EnvFilter, Registry}; - -use crate::core::error::{Error, Result}; - -static TRACING_SETUP_LOCK: OnceLock> = OnceLock::new(); -static TRACING_INITIALIZED: OnceLock<()> = OnceLock::new(); -static LOG_DESTINATION: OnceLock> = OnceLock::new(); - -fn tracing_setup_lock() -> &'static Mutex<()> { - TRACING_SETUP_LOCK.get_or_init(|| Mutex::new(())) -} - -fn log_destination() -> &'static Mutex { - LOG_DESTINATION.get_or_init(|| Mutex::new(LogDestination::default())) -} - -pub(crate) fn init_tracer(log_file_path: Option<&str>) -> Result<()> { - let _guard = tracing_setup_lock() - .lock() - .map_err(|_| Error::Other("failed to acquire tracing setup lock".to_string()))?; - - configure_file_output(log_file_path)?; - - if TRACING_INITIALIZED.get().is_some() { - return Ok(()); - } - - let subscriber = Registry::default().with( - tracing_subscriber::fmt::layer() - .with_ansi(false) - .with_writer(ContextVmMakeWriter) - .event_format(ContextVmEventFormatter) - .with_filter(build_env_filter()), - ); - - match tracing::subscriber::set_global_default(subscriber) { - Ok(()) => { - let _ = TRACING_INITIALIZED.set(()); - Ok(()) - } - Err(error) => { - let text = error.to_string(); - if text.contains("global default trace dispatcher has already been set") { - let _ = TRACING_INITIALIZED.set(()); - Ok(()) - } else { - Err(Error::Other(format!( - "failed to initialize tracing subscriber: {text}" - ))) - } - } - } -} - -fn configure_file_output(log_file_path: Option<&str>) -> Result<()> { - let Some(path) = normalize_log_file_path(log_file_path) else { - return Ok(()); - }; - - ensure_parent_exists(&path)?; - - let file = OpenOptions::new() - .create(true) - .append(true) - .open(&path) - .map_err(|error| { - Error::Other(format!( - "failed to open log file {}: {error}", - path.display() - )) - })?; - - let mut destination = log_destination() - .lock() - .map_err(|_| Error::Other("failed to acquire log destination lock".to_string()))?; - destination.file = Some(file); - - Ok(()) -} - -fn normalize_log_file_path(log_file_path: Option<&str>) -> Option { - let trimmed = log_file_path?.trim(); - if trimmed.is_empty() { - None - } else { - Some(PathBuf::from(trimmed)) - } -} - -fn ensure_parent_exists(path: &Path) -> Result<()> { - if let Some(parent) = path.parent() { - if !parent.as_os_str().is_empty() { - std::fs::create_dir_all(parent).map_err(|error| { - Error::Other(format!( - "failed to create log directory {}: {error}", - parent.display() - )) - })?; - } - } - - Ok(()) -} - -fn build_env_filter() -> EnvFilter { - EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("contextvm_sdk=info,rmcp=warn")) -} - -#[derive(Default)] -struct LogDestination { - file: Option, -} - -#[derive(Clone, Copy)] -struct ContextVmMakeWriter; - -impl<'a> MakeWriter<'a> for ContextVmMakeWriter { - type Writer = ContextVmWriter; - - fn make_writer(&'a self) -> Self::Writer { - ContextVmWriter { - stdout: io::stdout(), - } - } -} - -struct ContextVmWriter { - stdout: io::Stdout, -} - -impl Write for ContextVmWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.stdout.write_all(buf)?; - - if let Ok(mut destination) = log_destination().lock() { - if let Some(file) = destination.file.as_mut() { - let _ = file.write_all(buf); - } - } - - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - self.stdout.flush()?; - - if let Ok(mut destination) = log_destination().lock() { - if let Some(file) = destination.file.as_mut() { - let _ = file.flush(); - } - } - - Ok(()) - } -} - -#[derive(Default)] -struct MessageVisitor { - message: Option, - extra_fields: Vec<(String, String)>, -} - -impl MessageVisitor { - fn record_field(&mut self, name: &str, value: String) { - if name == "message" { - self.message = Some(value); - } else { - self.extra_fields.push((name.to_string(), value)); - } - } -} - -impl tracing::field::Visit for MessageVisitor { - fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { - self.record_field(field.name(), value.to_string()); - } - - fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { - self.record_field(field.name(), value.to_string()); - } - - fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { - self.record_field(field.name(), value.to_string()); - } - - fn record_str(&mut self, field: &tracing::field::Field, value: &str) { - self.record_field(field.name(), value.to_string()); - } - - fn record_error( - &mut self, - field: &tracing::field::Field, - value: &(dyn std::error::Error + 'static), - ) { - self.record_field(field.name(), value.to_string()); - } - - fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn fmt::Debug) { - self.record_field(field.name(), format!("{value:?}")); - } -} - -struct ContextVmEventFormatter; - -impl FormatEvent for ContextVmEventFormatter -where - S: tracing::Subscriber + for<'span> LookupSpan<'span>, - N: for<'writer> FormatFields<'writer> + 'static, -{ - fn format_event( - &self, - _ctx: &FmtContext<'_, S, N>, - mut writer: Writer<'_>, - event: &Event<'_>, - ) -> fmt::Result { - let mut visitor = MessageVisitor::default(); - event.record(&mut visitor); - - let metadata = event.metadata(); - let timestamp = unix_timestamp(); - let level = metadata.level().to_string().to_lowercase(); - let message = visitor.message.unwrap_or_default(); - - write!( - writer, - "{timestamp}:{level}::{}:{message}", - metadata.target() - )?; - - for (key, value) in visitor.extra_fields { - write!(writer, " {key}={value}")?; - } - - writeln!(writer) - } -} - -fn unix_timestamp() -> String { - use std::time::{SystemTime, UNIX_EPOCH}; - - let now = SystemTime::now(); - let duration = now.duration_since(UNIX_EPOCH).unwrap_or_default(); - format!("{}.{:03}", duration.as_secs(), duration.subsec_millis()) -} From f33ef68b8f57f010c6768df203f74007ad76f983 Mon Sep 17 00:00:00 2001 From: ContextVM Date: Thu, 7 May 2026 21:09:52 +0200 Subject: [PATCH 68/69] fix(rmcp): bridge stateless CEP-35 requests into rmcp lifecycle --- src/rmcp_transport/pipeline_tests.rs | 158 ++++++++++++++++++++++++++- src/rmcp_transport/worker.rs | 148 +++++++++++++++++++++++-- src/transport/server/mod.rs | 134 +++++++++++++++++++++++ 3 files changed, 432 insertions(+), 8 deletions(-) diff --git a/src/rmcp_transport/pipeline_tests.rs b/src/rmcp_transport/pipeline_tests.rs index a036799..ff4844e 100644 --- a/src/rmcp_transport/pipeline_tests.rs +++ b/src/rmcp_transport/pipeline_tests.rs @@ -14,18 +14,87 @@ #[cfg(all(test, feature = "rmcp"))] mod tests { + use std::sync::Arc; + use rmcp::model::{ - ClientJsonRpcMessage, ClientResult, RequestId, ServerJsonRpcMessage, ServerResult, + CallToolRequestParams, CallToolResult, ClientJsonRpcMessage, ClientResult, ErrorData, + Implementation, ProtocolVersion, RequestId, ServerCapabilities, ServerInfo, + ServerJsonRpcMessage, ServerResult, + }; + use rmcp::{ + handler::server::{router::tool::ToolRouter, wrapper::Parameters}, + schemars, tool, tool_handler, tool_router, ClientHandler, ServerHandler, ServiceExt, }; use crate::core::serializers; + use crate::core::types::{EncryptionMode, GiftWrapMode}; use crate::core::types::{ JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, }; + use crate::relay::mock::MockRelayPool; + use crate::relay::RelayPoolTrait; use crate::rmcp_transport::convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, rmcp_server_tx_to_internal, }; + use crate::transport::{ + client::{NostrClientTransport, NostrClientTransportConfig}, + server::{NostrServerTransport, NostrServerTransportConfig}, + }; + + #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] + struct EchoParams { + message: String, + } + + #[derive(Clone)] + struct StatelessTestServer { + tool_router: ToolRouter, + } + + impl StatelessTestServer { + fn new() -> Self { + Self { + tool_router: Self::tool_router(), + } + } + } + + #[tool_router] + impl StatelessTestServer { + #[tool(description = "Echo a message back unchanged")] + async fn echo( + &self, + Parameters(EchoParams { message }): Parameters, + ) -> Result { + Ok(CallToolResult::success(vec![rmcp::model::Content::text( + format!("Echo: {message}"), + )])) + } + } + + #[tool_handler] + impl ServerHandler for StatelessTestServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + protocol_version: ProtocolVersion::LATEST, + capabilities: ServerCapabilities::builder().enable_tools().build(), + server_info: Implementation { + name: "stateless-test-server".to_string(), + title: Some("Stateless Test Server".to_string()), + version: "0.1.0".to_string(), + description: Some("Stateless rmcp regression test server".to_string()), + icons: None, + website_url: None, + }, + instructions: Some("Use the echo tool".to_string()), + } + } + } + + #[derive(Clone, Default)] + struct StatelessTestClient; + impl ClientHandler for StatelessTestClient {} // ── Layer 1: Nostr event content → JsonRpcMessage ────────────────────── @@ -324,6 +393,93 @@ mod tests { } } + #[tokio::test] + async fn stateless_rmcp_roundtrip_over_mock_relay_preserves_correlation() { + let (server_pool, client_pool) = MockRelayPool::create_pair(); + let server_pubkey = server_pool + .public_key() + .await + .expect("server mock relay pubkey") + .to_hex(); + + let server_transport = NostrServerTransport::with_relay_pool( + NostrServerTransportConfig::default() + .with_relay_urls(vec!["mock://relay".to_string()]) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional), + Arc::new(server_pool), + ) + .await + .expect("server transport"); + + let server_task = tokio::spawn(async move { + StatelessTestServer::new() + .serve(server_transport) + .await + .expect("server should start") + .waiting() + .await + .expect("server should keep running until aborted"); + }); + + let client_transport = NostrClientTransport::with_relay_pool( + NostrClientTransportConfig::default() + .with_relay_urls(vec!["mock://relay".to_string()]) + .with_server_pubkey(server_pubkey) + .with_encryption_mode(EncryptionMode::Disabled) + .with_gift_wrap_mode(GiftWrapMode::Optional) + .with_stateless(true), + Arc::new(client_pool), + ) + .await + .expect("client transport"); + + let client = StatelessTestClient + .serve(client_transport) + .await + .expect("stateless client should start"); + + let peer_info = client + .peer_info() + .expect("peer info from emulated initialize"); + assert_eq!(peer_info.server_info.name, "Emulated-Stateless-Server"); + + let tools = client + .list_all_tools() + .await + .expect("tools/list should succeed"); + assert!( + tools.iter().any(|tool| tool.name == "echo"), + "expected echo tool from server" + ); + + let result = client + .call_tool(CallToolRequestParams { + name: "echo".into(), + arguments: serde_json::from_value(serde_json::json!({ + "message": "hello from stateless test" + })) + .ok(), + meta: None, + task: None, + }) + .await + .expect("tools/call should succeed"); + + let echoed = result + .content + .iter() + .find_map(|content| match &content.raw { + rmcp::model::RawContent::Text(text) => Some(text.text.clone()), + _ => None, + }) + .expect("echo response text"); + assert_eq!(echoed, "Echo: hello from stateless test"); + + client.cancel().await.expect("client cancel"); + server_task.abort(); + } + // ── Helper ────────────────────────────────────────────────────────────── fn make_request( diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 4263c8d..1e2c642 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -15,6 +15,7 @@ use super::convert::{ }; const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::worker"; +const STATELESS_SYNTHETIC_EVENT_ID: &str = "contextvm-stateless-init"; /// rmcp server worker wrapper for ContextVM Nostr server transport. /// @@ -96,19 +97,50 @@ impl Worker for NostrServerWorker { .. } = incoming; - // Rewrite the JSON-RPC request ID to the Nostr event_id. - // Event IDs are globally unique (SHA-256), so no collision - // across clients. The transport's event-route store maps - // event_id → (client_pubkey, original_request_id) and - // restores the original ID in `send_response`. - if let JsonRpcMessage::Request(ref mut req) = message { - req.id = serde_json::json!(event_id); + let is_synthetic_initialize = matches!( + &message, + JsonRpcMessage::Request(req) + if req.method == "initialize" + && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) + ); + + // Rewrite real wire requests to the Nostr event_id. + // Synthetic stateless bootstrap messages must retain their + // sentinel ID so their responses can be dropped before they + // ever touch transport correlation. + if !is_synthetic_initialize { + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!(event_id); + } } if let Some(rmcp_msg) = internal_to_rmcp_server_rx(&message) { if let Err(reason) = context.send_to_handler(rmcp_msg).await { break reason; } + + if is_synthetic_initialize { + let initialized = JsonRpcMessage::Notification( + crate::core::types::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }, + ); + + let Some(rmcp_initialized) = internal_to_rmcp_server_rx(&initialized) else { + break WorkerQuitReason::fatal( + Self::Error::Validation( + "failed converting synthetic initialized notification to rmcp format".to_string(), + ), + "converting synthetic initialized notification", + ); + }; + + if let Err(reason) = context.send_to_handler(rmcp_initialized).await { + break reason; + } + } } else { tracing::warn!( target: LOG_TARGET, @@ -272,6 +304,15 @@ impl NostrServerWorker { ) })?; + if event_id == STATELESS_SYNTHETIC_EVENT_ID { + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id, + "Dropping synthetic initialize response before wire transport" + ); + return Ok(()); + } + self.transport .send_response(&event_id, JsonRpcMessage::Response(resp)) .await @@ -283,6 +324,15 @@ impl NostrServerWorker { ) })?; + if event_id == STATELESS_SYNTHETIC_EVENT_ID { + tracing::debug!( + target: LOG_TARGET, + event_id = %event_id, + "Dropping synthetic initialize error before wire transport" + ); + return Ok(()); + } + self.transport .send_response(&event_id, JsonRpcMessage::ErrorResponse(resp)) .await @@ -298,3 +348,87 @@ impl NostrServerWorker { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::types::{JsonRpcRequest, JsonRpcResponse}; + + #[test] + fn test_synthetic_initialize_keeps_sentinel_id() { + let mut message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + })), + }); + + let is_synthetic_initialize = matches!( + &message, + JsonRpcMessage::Request(req) + if req.method == "initialize" + && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) + ); + + if !is_synthetic_initialize { + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!("real-event-id"); + } + } + + match message { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID)); + } + other => panic!("expected request, got {other:?}"), + } + } + + #[test] + fn test_real_request_is_rewritten_to_event_id() { + let mut message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + let is_synthetic_initialize = matches!( + &message, + JsonRpcMessage::Request(req) + if req.method == "initialize" + && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) + ); + + if !is_synthetic_initialize { + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!("real-event-id"); + } + } + + match message { + JsonRpcMessage::Request(req) => { + assert_eq!(req.id, serde_json::json!("real-event-id")); + } + other => panic!("expected request, got {other:?}"), + } + } + + #[test] + fn test_synthetic_initialize_response_uses_sentinel_for_drop() { + let message = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + result: serde_json::json!({}), + }); + + match message { + JsonRpcMessage::Response(resp) => { + assert_eq!(resp.id.as_str(), Some(STATELESS_SYNTHETIC_EVENT_ID)); + } + other => panic!("expected response, got {other:?}"), + } + } +} diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index ba9f6c7..f60e6f5 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -919,6 +919,54 @@ impl NostrServerTransport { }) } + #[cfg(feature = "rmcp")] + fn synthetic_initialize_message() -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("contextvm-stateless-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { + "name": "contextvm-stateless-client", + "version": "0.1.0" + } + })), + }) + } + + #[cfg(not(feature = "rmcp"))] + fn synthetic_initialize_message() -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!("contextvm-stateless-init"), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { + "name": "contextvm-stateless-client", + "version": "0.1.0" + } + })), + }) + } + + fn should_inject_synthetic_initialize( + session: &ClientSession, + mcp_msg: &JsonRpcMessage, + ) -> bool { + if session.is_initialized { + return false; + } + + matches!( + mcp_msg, + JsonRpcMessage::Request(req) if req.method != "initialize" + ) + } + #[allow(clippy::too_many_arguments)] async fn event_loop( relay_pool: Arc, @@ -1225,6 +1273,12 @@ impl NostrServerTransport { session.supports_oversized_transfer |= oversized_enabled && discovered.supports_oversized_transfer; + let should_inject_initialize = + Self::should_inject_synthetic_initialize(session, &mcp_msg); + if should_inject_initialize { + session.is_initialized = true; + } + // Track request for correlation if let JsonRpcMessage::Request(ref req) = mcp_msg { let original_id = req.id.clone(); @@ -1284,6 +1338,16 @@ impl NostrServerTransport { } } + // Forward a synthetic initialize first for stateless first-request sessions. + if should_inject_initialize { + let _ = tx.send(IncomingRequest { + message: Self::synthetic_initialize_message(), + client_pubkey: sender_pubkey.clone(), + event_id: event_id.clone(), + is_encrypted, + }); + } + // Forward to consumer let _ = tx.send(IncomingRequest { message: mcp_msg, @@ -1562,6 +1626,76 @@ mod tests { )); } + #[test] + fn test_should_inject_synthetic_initialize_for_first_non_initialize_request() { + let session = ClientSession::new(false); + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(NostrServerTransport::should_inject_synthetic_initialize( + &session, &message, + )); + } + + #[test] + fn test_should_not_inject_synthetic_initialize_for_real_initialize_request() { + let session = ClientSession::new(false); + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "initialize".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(!NostrServerTransport::should_inject_synthetic_initialize( + &session, &message, + )); + } + + #[test] + fn test_should_not_inject_synthetic_initialize_after_session_initialized() { + let mut session = ClientSession::new(false); + session.is_initialized = true; + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(!NostrServerTransport::should_inject_synthetic_initialize( + &session, &message, + )); + } + + #[test] + fn test_synthetic_initialize_message_shape() { + let message = NostrServerTransport::synthetic_initialize_message(); + let JsonRpcMessage::Request(request) = message else { + panic!("expected request"); + }; + + assert_eq!(request.method, "initialize"); + assert_eq!(request.id, serde_json::json!("contextvm-stateless-init")); + + let params = request.params.expect("initialize params"); + assert_eq!( + params.get("protocolVersion").and_then(|v| v.as_str()), + Some(crate::core::constants::mcp_protocol_version()) + ); + assert_eq!( + params + .get("clientInfo") + .and_then(|v| v.get("name")) + .and_then(|v| v.as_str()), + Some("contextvm-stateless-client") + ); + } + #[test] fn test_method_excluded_without_name() { let exclusions = vec![CapabilityExclusion { From efe52eabf791fc130a0e3f3783d0f441e10c43f9 Mon Sep 17 00:00:00 2001 From: ContextVM Date: Thu, 7 May 2026 21:20:21 +0200 Subject: [PATCH 69/69] refactor(rmcp): move stateless client bootstrap logic from transport to worker Moves the synthetic initialize/initialized message injection logic from the NostrServerTransport to the NostrServerWorker. This better separates concerns by having the rmcp layer handle MCP protocol lifecycle while the transport layer focuses on Nostr-specific communication. Also adds a HashSet to track initialized clients per pubkey for proper bootstrap injection. --- src/rmcp_transport/worker.rs | 204 +++++++++++++++++++++++++---------- src/transport/server/mod.rs | 134 ----------------------- 2 files changed, 145 insertions(+), 193 deletions(-) diff --git a/src/rmcp_transport/worker.rs b/src/rmcp_transport/worker.rs index 1e2c642..023354f 100644 --- a/src/rmcp_transport/worker.rs +++ b/src/rmcp_transport/worker.rs @@ -4,10 +4,11 @@ //! transports to rmcp's worker abstraction. use crate::core::error::Result; -use crate::core::types::JsonRpcMessage; +use crate::core::types::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest}; use crate::transport::client::{NostrClientTransport, NostrClientTransportConfig}; use crate::transport::server::{NostrServerTransport, NostrServerTransportConfig}; use rmcp::transport::worker::{Worker, WorkerContext, WorkerQuitReason}; +use std::collections::HashSet; use super::convert::{ internal_to_rmcp_client_rx, internal_to_rmcp_server_rx, rmcp_client_tx_to_internal, @@ -17,6 +18,51 @@ use super::convert::{ const LOG_TARGET: &str = "contextvm_sdk::rmcp_transport::worker"; const STATELESS_SYNTHETIC_EVENT_ID: &str = "contextvm-stateless-init"; +fn synthetic_initialize_message() -> JsonRpcMessage { + JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + method: "initialize".to_string(), + params: Some(serde_json::json!({ + "protocolVersion": crate::core::constants::mcp_protocol_version(), + "capabilities": {}, + "clientInfo": { + "name": "contextvm-stateless-client", + "version": "0.1.0" + } + })), + }) +} + +fn synthetic_initialized_notification() -> JsonRpcMessage { + JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "notifications/initialized".to_string(), + params: None, + }) +} + +fn should_inject_stateless_bootstrap( + initialized_clients: &HashSet, + client_pubkey: &str, + message: &JsonRpcMessage, +) -> bool { + if initialized_clients.contains(client_pubkey) { + return false; + } + + matches!(message, JsonRpcMessage::Request(req) if req.method != "initialize") +} + +fn is_synthetic_initialize_message(message: &JsonRpcMessage) -> bool { + matches!( + message, + JsonRpcMessage::Request(req) + if req.method == "initialize" + && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) + ) +} + /// rmcp server worker wrapper for ContextVM Nostr server transport. /// /// Multiplexes all connected clients through a single rmcp service instance. @@ -80,6 +126,7 @@ impl Worker for NostrServerWorker { })?; let cancellation_token = context.cancellation_token.clone(); + let mut initialized_clients = HashSet::new(); let quit_reason = loop { tokio::select! { @@ -94,23 +141,61 @@ impl Worker for NostrServerWorker { let crate::transport::server::IncomingRequest { mut message, event_id, + client_pubkey, .. } = incoming; - let is_synthetic_initialize = matches!( + let should_inject_bootstrap = should_inject_stateless_bootstrap( + &initialized_clients, + &client_pubkey, &message, - JsonRpcMessage::Request(req) - if req.method == "initialize" - && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) ); + if should_inject_bootstrap { + let synthetic_init = synthetic_initialize_message(); + let Some(rmcp_init) = internal_to_rmcp_server_rx(&synthetic_init) else { + break WorkerQuitReason::fatal( + Self::Error::Validation( + "failed converting synthetic initialize request to rmcp format".to_string(), + ), + "converting synthetic initialize request", + ); + }; + + if let Err(reason) = context.send_to_handler(rmcp_init).await { + break reason; + } + + let initialized = synthetic_initialized_notification(); + let Some(rmcp_initialized) = internal_to_rmcp_server_rx(&initialized) else { + break WorkerQuitReason::fatal( + Self::Error::Validation( + "failed converting synthetic initialized notification to rmcp format".to_string(), + ), + "converting synthetic initialized notification", + ); + }; + + if let Err(reason) = context.send_to_handler(rmcp_initialized).await { + break reason; + } + + initialized_clients.insert(client_pubkey.clone()); + } + + if matches!(&message, JsonRpcMessage::Request(req) if req.method == "initialize") + || matches!(&message, JsonRpcMessage::Notification(n) if n.method == "notifications/initialized") + { + initialized_clients.insert(client_pubkey.clone()); + } + // Rewrite real wire requests to the Nostr event_id. // Synthetic stateless bootstrap messages must retain their // sentinel ID so their responses can be dropped before they // ever touch transport correlation. - if !is_synthetic_initialize { + if !is_synthetic_initialize_message(&message) { if let JsonRpcMessage::Request(ref mut req) = message { - req.id = serde_json::json!(event_id); + req.id = serde_json::json!(event_id); } } @@ -118,29 +203,6 @@ impl Worker for NostrServerWorker { if let Err(reason) = context.send_to_handler(rmcp_msg).await { break reason; } - - if is_synthetic_initialize { - let initialized = JsonRpcMessage::Notification( - crate::core::types::JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: "notifications/initialized".to_string(), - params: None, - }, - ); - - let Some(rmcp_initialized) = internal_to_rmcp_server_rx(&initialized) else { - break WorkerQuitReason::fatal( - Self::Error::Validation( - "failed converting synthetic initialized notification to rmcp format".to_string(), - ), - "converting synthetic initialized notification", - ); - }; - - if let Err(reason) = context.send_to_handler(rmcp_initialized).await { - break reason; - } - } } else { tracing::warn!( target: LOG_TARGET, @@ -352,35 +414,50 @@ impl NostrServerWorker { #[cfg(test)] mod tests { use super::*; - use crate::core::types::{JsonRpcRequest, JsonRpcResponse}; + use crate::core::types::JsonRpcResponse; #[test] - fn test_synthetic_initialize_keeps_sentinel_id() { - let mut message = JsonRpcMessage::Request(JsonRpcRequest { + fn test_should_inject_stateless_bootstrap_for_first_non_initialize_request() { + let initialized_clients = HashSet::new(); + let message = JsonRpcMessage::Request(JsonRpcRequest { jsonrpc: "2.0".to_string(), - id: serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID), + id: serde_json::json!(1), + method: "tools/list".to_string(), + params: Some(serde_json::json!({})), + }); + + assert!(should_inject_stateless_bootstrap( + &initialized_clients, + "client-a", + &message, + )); + } + + #[test] + fn test_should_not_inject_stateless_bootstrap_for_real_initialize() { + let initialized_clients = HashSet::new(); + let message = JsonRpcMessage::Request(JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: serde_json::json!(1), method: "initialize".to_string(), - params: Some(serde_json::json!({ - "protocolVersion": crate::core::constants::mcp_protocol_version(), - })), + params: Some(serde_json::json!({})), }); - let is_synthetic_initialize = matches!( + assert!(!should_inject_stateless_bootstrap( + &initialized_clients, + "client-a", &message, - JsonRpcMessage::Request(req) - if req.method == "initialize" - && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) - ); - - if !is_synthetic_initialize { - if let JsonRpcMessage::Request(ref mut req) = message { - req.id = serde_json::json!("real-event-id"); - } - } + )); + } + + #[test] + fn test_synthetic_initialize_keeps_sentinel_id() { + let message = synthetic_initialize_message(); match message { JsonRpcMessage::Request(req) => { assert_eq!(req.id, serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID)); + assert_eq!(req.method, "initialize"); } other => panic!("expected request, got {other:?}"), } @@ -395,17 +472,8 @@ mod tests { params: Some(serde_json::json!({})), }); - let is_synthetic_initialize = matches!( - &message, - JsonRpcMessage::Request(req) - if req.method == "initialize" - && req.id == serde_json::json!(STATELESS_SYNTHETIC_EVENT_ID) - ); - - if !is_synthetic_initialize { - if let JsonRpcMessage::Request(ref mut req) = message { - req.id = serde_json::json!("real-event-id"); - } + if let JsonRpcMessage::Request(ref mut req) = message { + req.id = serde_json::json!("real-event-id"); } match message { @@ -431,4 +499,22 @@ mod tests { other => panic!("expected response, got {other:?}"), } } + + #[test] + fn test_synthetic_initialized_notification_shape() { + let message = synthetic_initialized_notification(); + match message { + JsonRpcMessage::Notification(notification) => { + assert_eq!(notification.method, "notifications/initialized"); + } + other => panic!("expected notification, got {other:?}"), + } + } + + #[test] + fn test_is_synthetic_initialize_message_detects_sentinel() { + assert!(is_synthetic_initialize_message( + &synthetic_initialize_message() + )); + } } diff --git a/src/transport/server/mod.rs b/src/transport/server/mod.rs index f60e6f5..ba9f6c7 100644 --- a/src/transport/server/mod.rs +++ b/src/transport/server/mod.rs @@ -919,54 +919,6 @@ impl NostrServerTransport { }) } - #[cfg(feature = "rmcp")] - fn synthetic_initialize_message() -> JsonRpcMessage { - JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!("contextvm-stateless-init"), - method: "initialize".to_string(), - params: Some(serde_json::json!({ - "protocolVersion": crate::core::constants::mcp_protocol_version(), - "capabilities": {}, - "clientInfo": { - "name": "contextvm-stateless-client", - "version": "0.1.0" - } - })), - }) - } - - #[cfg(not(feature = "rmcp"))] - fn synthetic_initialize_message() -> JsonRpcMessage { - JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!("contextvm-stateless-init"), - method: "initialize".to_string(), - params: Some(serde_json::json!({ - "protocolVersion": crate::core::constants::mcp_protocol_version(), - "capabilities": {}, - "clientInfo": { - "name": "contextvm-stateless-client", - "version": "0.1.0" - } - })), - }) - } - - fn should_inject_synthetic_initialize( - session: &ClientSession, - mcp_msg: &JsonRpcMessage, - ) -> bool { - if session.is_initialized { - return false; - } - - matches!( - mcp_msg, - JsonRpcMessage::Request(req) if req.method != "initialize" - ) - } - #[allow(clippy::too_many_arguments)] async fn event_loop( relay_pool: Arc, @@ -1273,12 +1225,6 @@ impl NostrServerTransport { session.supports_oversized_transfer |= oversized_enabled && discovered.supports_oversized_transfer; - let should_inject_initialize = - Self::should_inject_synthetic_initialize(session, &mcp_msg); - if should_inject_initialize { - session.is_initialized = true; - } - // Track request for correlation if let JsonRpcMessage::Request(ref req) = mcp_msg { let original_id = req.id.clone(); @@ -1338,16 +1284,6 @@ impl NostrServerTransport { } } - // Forward a synthetic initialize first for stateless first-request sessions. - if should_inject_initialize { - let _ = tx.send(IncomingRequest { - message: Self::synthetic_initialize_message(), - client_pubkey: sender_pubkey.clone(), - event_id: event_id.clone(), - is_encrypted, - }); - } - // Forward to consumer let _ = tx.send(IncomingRequest { message: mcp_msg, @@ -1626,76 +1562,6 @@ mod tests { )); } - #[test] - fn test_should_inject_synthetic_initialize_for_first_non_initialize_request() { - let session = ClientSession::new(false); - let message = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: "tools/list".to_string(), - params: Some(serde_json::json!({})), - }); - - assert!(NostrServerTransport::should_inject_synthetic_initialize( - &session, &message, - )); - } - - #[test] - fn test_should_not_inject_synthetic_initialize_for_real_initialize_request() { - let session = ClientSession::new(false); - let message = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: "initialize".to_string(), - params: Some(serde_json::json!({})), - }); - - assert!(!NostrServerTransport::should_inject_synthetic_initialize( - &session, &message, - )); - } - - #[test] - fn test_should_not_inject_synthetic_initialize_after_session_initialized() { - let mut session = ClientSession::new(false); - session.is_initialized = true; - let message = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: serde_json::json!(1), - method: "tools/list".to_string(), - params: Some(serde_json::json!({})), - }); - - assert!(!NostrServerTransport::should_inject_synthetic_initialize( - &session, &message, - )); - } - - #[test] - fn test_synthetic_initialize_message_shape() { - let message = NostrServerTransport::synthetic_initialize_message(); - let JsonRpcMessage::Request(request) = message else { - panic!("expected request"); - }; - - assert_eq!(request.method, "initialize"); - assert_eq!(request.id, serde_json::json!("contextvm-stateless-init")); - - let params = request.params.expect("initialize params"); - assert_eq!( - params.get("protocolVersion").and_then(|v| v.as_str()), - Some(crate::core::constants::mcp_protocol_version()) - ); - assert_eq!( - params - .get("clientInfo") - .and_then(|v| v.get("name")) - .and_then(|v| v.as_str()), - Some("contextvm-stateless-client") - ); - } - #[test] fn test_method_excluded_without_name() { let exclusions = vec![CapabilityExclusion {