diff --git a/python/langsmith/sandbox/__init__.py b/python/langsmith/sandbox/__init__.py index 809aa2f57..82a4d7b3e 100644 --- a/python/langsmith/sandbox/__init__.py +++ b/python/langsmith/sandbox/__init__.py @@ -30,6 +30,14 @@ from langsmith.sandbox._async_client import AsyncSandboxClient from langsmith.sandbox._async_sandbox import AsyncSandbox +from langsmith.sandbox._callback_verifier import ( + SANDBOX_CALLBACK_SIGNATURE_HEADER, + SANDBOX_CALLBACK_SUBJECT, + SandboxCallbackClaims, + SandboxCallbackIdentity, + SandboxCallbackVerificationError, + SandboxCallbackVerifier, +) from langsmith.sandbox._client import SandboxClient from langsmith.sandbox._exceptions import ( CommandTimeoutError, @@ -73,6 +81,13 @@ "AsyncSandboxClient", "Sandbox", "AsyncSandbox", + # Callback verification + "SANDBOX_CALLBACK_SIGNATURE_HEADER", + "SANDBOX_CALLBACK_SUBJECT", + "SandboxCallbackClaims", + "SandboxCallbackIdentity", + "SandboxCallbackVerifier", + "SandboxCallbackVerificationError", # Models "ResourceStatus", "ExecutionResult", diff --git a/python/langsmith/sandbox/_callback_verifier.py b/python/langsmith/sandbox/_callback_verifier.py new file mode 100644 index 000000000..f7a25360b --- /dev/null +++ b/python/langsmith/sandbox/_callback_verifier.py @@ -0,0 +1,490 @@ +"""Verify signed LangSmith sandbox callback requests.""" + +from __future__ import annotations + +import base64 +import binascii +import dataclasses +import hashlib +import json +import threading +import time +from collections.abc import Mapping, Sequence +from datetime import datetime, timezone +from typing import Any, Optional, Union +from urllib.parse import urlparse + +import httpx + +from langsmith import utils as ls_utils + +SANDBOX_CALLBACK_SIGNATURE_HEADER = "X-LangSmith-Signature-JWT" +SANDBOX_CALLBACK_SUBJECT = "langsmith-sandbox-callback" + +_JWKS_PATH = "/.well-known/jwks.json" +_HeaderValue = Union[str, bytes] +_Headers = Union[ + Mapping[str, _HeaderValue], Sequence[tuple[_HeaderValue, _HeaderValue]] +] + + +class SandboxCallbackVerificationError(ValueError): + """Raised when a sandbox callback signature cannot be verified.""" + + +@dataclasses.dataclass(frozen=True) +class SandboxCallbackIdentity: + """Verified LangSmith sandbox callback body identity.""" + + tenant_id: str + organization_id: str + sandbox_id: str + ls_user_id: Optional[str] = None + + +@dataclasses.dataclass(frozen=True) +class SandboxCallbackClaims: + """Verified LangSmith sandbox callback JWT claims and body identity.""" + + issuer: str + audience: Union[str, tuple[str, ...]] + subject: str + identity: SandboxCallbackIdentity + body_sha256: str + jti: str + issued_at: datetime + not_before: datetime + expires_at: datetime + + @property + def tenant_id(self) -> str: + return self.identity.tenant_id + + @property + def organization_id(self) -> str: + return self.identity.organization_id + + @property + def sandbox_id(self) -> str: + return self.identity.sandbox_id + + @property + def ls_user_id(self) -> Optional[str]: + return self.identity.ls_user_id + + +class SandboxCallbackVerifier: + """Verifier for LangSmith sandbox callback requests.""" + + def __init__( + self, + *, + api_url: Optional[str] = None, + api_key: Optional[str] = None, + workspace_id: Optional[str] = None, + expected_tenant_id: Optional[str] = None, + expected_organization_id: Optional[str] = None, + leeway_seconds: float = 60, + jwks_cache_ttl_seconds: float = 300, + timeout: float = 5.0, + ) -> None: + """Initialize the verifier using normal LangSmith SDK endpoint defaults.""" + self.api_url = ls_utils.get_api_url(api_url) + self.api_key = ls_utils.get_api_key(api_key) + self.workspace_id = ls_utils.get_workspace_id(workspace_id) + self.issuer_url = _origin_url(self.api_url) + self.expected_organization_id = _clean_optional(expected_organization_id) + self.leeway_seconds = leeway_seconds + self.jwks_cache_ttl_seconds = jwks_cache_ttl_seconds + self.timeout = timeout + + self._expected_tenant_id = _clean_optional(expected_tenant_id) + self._jwks: Optional[Mapping[str, Any]] = None + self._jwks_expires_at = 0.0 + self._lock = threading.Lock() + + @classmethod + def from_client( + cls, + client: Any, + *, + expected_tenant_id: Optional[str] = None, + expected_organization_id: Optional[str] = None, + leeway_seconds: float = 60, + jwks_cache_ttl_seconds: float = 300, + timeout: float = 5.0, + ) -> SandboxCallbackVerifier: + """Create a verifier from an existing LangSmith client.""" + return cls( + api_url=getattr(client, "api_url", None), + api_key=getattr(client, "api_key", None), + workspace_id=getattr(client, "workspace_id", None), + expected_tenant_id=expected_tenant_id, + expected_organization_id=expected_organization_id, + leeway_seconds=leeway_seconds, + jwks_cache_ttl_seconds=jwks_cache_ttl_seconds, + timeout=timeout, + ) + + def verify( + self, + headers: _Headers, + body: Union[bytes, bytearray, memoryview, str], + *, + expected_sandbox_id: Optional[str] = None, + unsafely_allow_any_sandbox_id: bool = False, + ) -> SandboxCallbackClaims: + """Verify a sandbox callback request. + + Args: + headers: Incoming callback request headers. + body: Exact raw callback request body bytes. + expected_sandbox_id: Sandbox ID the callback is expected to be for. + unsafely_allow_any_sandbox_id: Skip sandbox ID matching. This is only + safe if the caller verifies ``claims.identity.sandbox_id`` before + trusting the callback. + """ + expected_sandbox_id = _clean_optional(expected_sandbox_id) + if expected_sandbox_id and unsafely_allow_any_sandbox_id: + raise SandboxCallbackVerificationError( + "Pass expected_sandbox_id or unsafely_allow_any_sandbox_id=True, " + "not both" + ) + if not expected_sandbox_id and not unsafely_allow_any_sandbox_id: + raise SandboxCallbackVerificationError( + "expected_sandbox_id is required unless " + "unsafely_allow_any_sandbox_id=True" + ) + + signature_jwt = _get_header(headers, SANDBOX_CALLBACK_SIGNATURE_HEADER) + if not signature_jwt: + raise SandboxCallbackVerificationError( + f"Missing {SANDBOX_CALLBACK_SIGNATURE_HEADER} header" + ) + + header, claims, signing_input, signature = _decode_jwt(signature_jwt) + if header.get("alg") != "EdDSA": + raise SandboxCallbackVerificationError( + "Sandbox callback JWT must use EdDSA" + ) + kid = header.get("kid") + if not isinstance(kid, str) or not kid: + raise SandboxCallbackVerificationError( + "Sandbox callback JWT is missing kid" + ) + + key = self._find_key(kid) + _verify_eddsa_signature(key, signing_input, signature) + + body_bytes = _body_bytes(body) + expected_body_hash = hashlib.sha256(body_bytes).hexdigest() + body_sha256 = claims.get("body_sha256") + if body_sha256 != expected_body_hash: + raise SandboxCallbackVerificationError( + "Sandbox callback body hash mismatch" + ) + + identity = _identity_from_body(body_bytes) + expected_tenant_id = self._get_expected_tenant_id() + _validate_claims( + claims, + issuer_url=self.issuer_url, + leeway_seconds=self.leeway_seconds, + ) + _validate_identity( + identity, + expected_tenant_id=expected_tenant_id, + expected_organization_id=self.expected_organization_id, + expected_sandbox_id=expected_sandbox_id, + ) + + return _claims_to_result(claims, identity) + + def _find_key(self, kid: str) -> Mapping[str, Any]: + try: + return _find_jwk(self._get_jwks(), kid) + except SandboxCallbackVerificationError: + return _find_jwk(self._get_jwks(force=True), kid) + + def _get_jwks(self, *, force: bool = False) -> Mapping[str, Any]: + now = time.time() + with self._lock: + if not force and self._jwks is not None and self._jwks_expires_at > now: + return self._jwks + url = self.issuer_url + _JWKS_PATH + try: + response = httpx.get(url, timeout=self.timeout, follow_redirects=False) + response.raise_for_status() + jwks = response.json() + except Exception as exc: + msg = f"Failed to fetch LangSmith sandbox callback JWKS from {url!r}" + raise SandboxCallbackVerificationError(msg) from exc + if not isinstance(jwks, Mapping): + raise SandboxCallbackVerificationError( + "JWKS response must be a JSON object" + ) + self._jwks = jwks + self._jwks_expires_at = time.time() + self.jwks_cache_ttl_seconds + return jwks + + def _get_expected_tenant_id(self) -> str: + with self._lock: + if self._expected_tenant_id: + return self._expected_tenant_id + url = self.api_url.rstrip("/") + "/sessions" + try: + response = httpx.get( + url, + params={"limit": 1}, + headers=self._auth_headers(), + timeout=self.timeout, + follow_redirects=False, + ) + response.raise_for_status() + data = response.json() + except Exception as exc: + raise SandboxCallbackVerificationError( + "Failed to fetch LangSmith tenant ID" + ) from exc + if not isinstance(data, list) or not data: + raise SandboxCallbackVerificationError("No LangSmith tenant ID found") + tenant_id = ( + data[0].get("tenant_id") if isinstance(data[0], Mapping) else None + ) + if not isinstance(tenant_id, str) or not tenant_id: + raise SandboxCallbackVerificationError("No LangSmith tenant ID found") + self._expected_tenant_id = tenant_id + return tenant_id + + def _auth_headers(self) -> dict[str, str]: + headers = {"Accept": "application/json"} + if self.api_key: + headers["X-Api-Key"] = self.api_key + if self.workspace_id: + headers["X-Tenant-Id"] = self.workspace_id + return headers + + +def _decode_jwt( + token: str, +) -> tuple[Mapping[str, Any], Mapping[str, Any], bytes, bytes]: + parts = token.split(".") + if len(parts) != 3: + raise SandboxCallbackVerificationError("Sandbox callback JWT is malformed") + try: + header = json.loads(_b64url_decode(parts[0])) + claims = json.loads(_b64url_decode(parts[1])) + signature = _b64url_decode(parts[2]) + except (binascii.Error, UnicodeDecodeError, json.JSONDecodeError) as exc: + raise SandboxCallbackVerificationError( + "Sandbox callback JWT is malformed" + ) from exc + if not isinstance(header, Mapping) or not isinstance(claims, Mapping): + raise SandboxCallbackVerificationError("Sandbox callback JWT is malformed") + return header, claims, f"{parts[0]}.{parts[1]}".encode(), signature + + +def _find_jwk(jwks: Mapping[str, Any], kid: str) -> Mapping[str, Any]: + keys = jwks.get("keys") + if not isinstance(keys, Sequence) or isinstance(keys, (str, bytes)): + raise SandboxCallbackVerificationError("JWKS must contain a keys array") + for key in keys: + if isinstance(key, Mapping) and key.get("kid") == kid: + return key + raise SandboxCallbackVerificationError("No matching JWK for sandbox callback JWT") + + +def _verify_eddsa_signature( + jwk: Mapping[str, Any], + signing_input: bytes, + signature: bytes, +) -> None: + if jwk.get("kty") != "OKP" or jwk.get("crv") != "Ed25519": + raise SandboxCallbackVerificationError("Sandbox callback JWK must be Ed25519") + x = jwk.get("x") + if not isinstance(x, str) or not x: + raise SandboxCallbackVerificationError("Sandbox callback JWK is missing x") + try: + from cryptography.exceptions import InvalidSignature + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PublicKey, + ) + except ImportError as exc: + raise SandboxCallbackVerificationError( + "Install cryptography to verify sandbox callback signatures" + ) from exc + try: + public_key = Ed25519PublicKey.from_public_bytes(_b64url_decode(x)) + public_key.verify(signature, signing_input) + except InvalidSignature as exc: + raise SandboxCallbackVerificationError( + "Sandbox callback JWT signature is invalid" + ) from exc + except Exception as exc: + raise SandboxCallbackVerificationError( + "Sandbox callback JWK is invalid" + ) from exc + + +def _validate_claims( + claims: Mapping[str, Any], + *, + issuer_url: str, + leeway_seconds: float, +) -> None: + now = time.time() + if claims.get("iss") != issuer_url: + raise SandboxCallbackVerificationError("Sandbox callback issuer mismatch") + if claims.get("sub") != SANDBOX_CALLBACK_SUBJECT: + raise SandboxCallbackVerificationError("Sandbox callback subject mismatch") + if not _non_empty_audience(claims.get("aud")): + raise SandboxCallbackVerificationError("Sandbox callback JWT is missing aud") + exp = _numeric_claim(claims, "exp") + nbf = _numeric_claim(claims, "nbf") + _numeric_claim(claims, "iat") + if exp + leeway_seconds < now: + raise SandboxCallbackVerificationError("Sandbox callback JWT is expired") + if nbf - leeway_seconds > now: + raise SandboxCallbackVerificationError("Sandbox callback JWT is not yet valid") + for name in ("jti", "body_sha256"): + value = claims.get(name) + if not isinstance(value, str) or not value: + raise SandboxCallbackVerificationError( + f"Sandbox callback JWT is missing {name}" + ) + + +def _validate_identity( + identity: SandboxCallbackIdentity, + *, + expected_tenant_id: str, + expected_organization_id: Optional[str], + expected_sandbox_id: Optional[str], +) -> None: + if identity.tenant_id != expected_tenant_id: + raise SandboxCallbackVerificationError("Sandbox callback tenant_id mismatch") + if ( + expected_organization_id is not None + and identity.organization_id != expected_organization_id + ): + raise SandboxCallbackVerificationError( + "Sandbox callback organization_id mismatch" + ) + if expected_sandbox_id is not None and identity.sandbox_id != expected_sandbox_id: + raise SandboxCallbackVerificationError("Sandbox callback sandbox_id mismatch") + + +def _claims_to_result( + claims: Mapping[str, Any], + identity: SandboxCallbackIdentity, +) -> SandboxCallbackClaims: + return SandboxCallbackClaims( + issuer=str(claims["iss"]), + audience=_audience_result(claims["aud"]), + subject=str(claims["sub"]), + identity=identity, + body_sha256=str(claims["body_sha256"]), + jti=str(claims["jti"]), + issued_at=_datetime_from_timestamp(_numeric_claim(claims, "iat")), + not_before=_datetime_from_timestamp(_numeric_claim(claims, "nbf")), + expires_at=_datetime_from_timestamp(_numeric_claim(claims, "exp")), + ) + + +def _identity_from_body(body: bytes) -> SandboxCallbackIdentity: + try: + payload = json.loads(body) + except (UnicodeDecodeError, json.JSONDecodeError) as exc: + raise SandboxCallbackVerificationError( + "Sandbox callback body is malformed" + ) from exc + if not isinstance(payload, Mapping): + raise SandboxCallbackVerificationError("Sandbox callback body is malformed") + identity = payload.get("identity") + if not isinstance(identity, Mapping): + raise SandboxCallbackVerificationError( + "Sandbox callback body is missing identity" + ) + values: dict[str, str] = {} + for name in ("tenant_id", "organization_id", "sandbox_id"): + value = identity.get(name) + if not isinstance(value, str) or not value: + raise SandboxCallbackVerificationError( + f"Sandbox callback body identity is missing {name}" + ) + values[name] = value + raw_ls_user_id = identity.get("ls_user_id") + if raw_ls_user_id is not None and ( + not isinstance(raw_ls_user_id, str) or not raw_ls_user_id + ): + raise SandboxCallbackVerificationError( + "Sandbox callback body identity has invalid ls_user_id" + ) + ls_user_id = raw_ls_user_id if isinstance(raw_ls_user_id, str) else None + return SandboxCallbackIdentity( + tenant_id=values["tenant_id"], + organization_id=values["organization_id"], + sandbox_id=values["sandbox_id"], + ls_user_id=ls_user_id, + ) + + +def _non_empty_audience(audience: Any) -> bool: + if isinstance(audience, str): + return bool(audience) + if isinstance(audience, Sequence) and not isinstance(audience, (bytes, bytearray)): + return any(isinstance(item, str) and item for item in audience) + return False + + +def _audience_result(audience: Any) -> Union[str, tuple[str, ...]]: + if isinstance(audience, str): + return audience + return tuple(item for item in audience if isinstance(item, str)) + + +def _numeric_claim(claims: Mapping[str, Any], name: str) -> float: + value = claims.get(name) + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise SandboxCallbackVerificationError( + f"Sandbox callback JWT is missing numeric {name}" + ) + return float(value) + + +def _datetime_from_timestamp(value: float) -> datetime: + return datetime.fromtimestamp(value, tz=timezone.utc) + + +def _get_header(headers: _Headers, name: str) -> Optional[str]: + needle = name.lower() + items = headers.items() if isinstance(headers, Mapping) else headers + for key, value in items: + key_str = key.decode() if isinstance(key, bytes) else key + if key_str.lower() == needle: + return value.decode() if isinstance(value, bytes) else value + return None + + +def _body_bytes(body: Union[bytes, bytearray, memoryview, str]) -> bytes: + if isinstance(body, str): + return body.encode() + return bytes(body) + + +def _b64url_decode(value: str) -> bytes: + return base64.urlsafe_b64decode(value + "=" * (-len(value) % 4)) + + +def _origin_url(url: str) -> str: + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + raise SandboxCallbackVerificationError("LangSmith API URL must be absolute") + return f"{parsed.scheme}://{parsed.netloc}" + + +def _clean_optional(value: Optional[str]) -> Optional[str]: + if value is None: + return None + value = value.strip().strip('"').strip("'") + return value or None diff --git a/python/tests/unit_tests/sandbox/test_callback_verifier.py b/python/tests/unit_tests/sandbox/test_callback_verifier.py new file mode 100644 index 000000000..6989eb9de --- /dev/null +++ b/python/tests/unit_tests/sandbox/test_callback_verifier.py @@ -0,0 +1,421 @@ +import base64 +import hashlib +import json +import time +from typing import Any, Mapping, Optional + +import httpx +import pytest +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey +from cryptography.hazmat.primitives.serialization import ( + Encoding, + PublicFormat, +) + +from langsmith import utils as ls_utils +from langsmith.sandbox import ( + SANDBOX_CALLBACK_SIGNATURE_HEADER, + SANDBOX_CALLBACK_SUBJECT, + SandboxCallbackClaims, + SandboxCallbackVerificationError, + SandboxCallbackVerifier, +) + + +def _callback_body( + *, + tenant_id: str = "tenant-1", + organization_id: str = "org-1", + sandbox_id: str = "sandbox-1", + ls_user_id: Optional[str] = "user-1", + host: str = "example.com", + port: int = 80, +) -> bytes: + identity = { + "tenant_id": tenant_id, + "organization_id": organization_id, + "sandbox_id": sandbox_id, + } + if ls_user_id is not None: + identity["ls_user_id"] = ls_user_id + return json.dumps( + {"host": host, "port": port, "identity": identity}, + separators=(",", ":"), + ).encode() + + +def test_verifier_uses_sdk_defaults_and_caches_tenant_and_jwks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("LANGSMITH_ENDPOINT", "https://smith.example/api") + monkeypatch.setenv("LANGSMITH_API_KEY", "test-key") + ls_utils.get_env_var.cache_clear() + body = _callback_body() + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + ) + calls = _install_fake_langsmith_get( + monkeypatch, + key, + tenant_id="tenant-1", + api_url="https://smith.example/api", + ) + verifier = SandboxCallbackVerifier() + + for _ in range(2): + claims = verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER.lower(): token}, + body, + expected_sandbox_id="sandbox-1", + ) + assert isinstance(claims, SandboxCallbackClaims) + assert claims.issuer == "https://smith.example" + assert claims.audience == "https://customer.example/callback" + assert claims.subject == SANDBOX_CALLBACK_SUBJECT + assert claims.tenant_id == "tenant-1" + assert claims.organization_id == "org-1" + assert claims.sandbox_id == "sandbox-1" + assert claims.ls_user_id == "user-1" + assert claims.identity.ls_user_id == "user-1" + assert claims.body_sha256 == hashlib.sha256(body).hexdigest() + + assert [call["url"] for call in calls] == [ + "https://smith.example/.well-known/jwks.json", + "https://smith.example/api/sessions", + ] + assert calls[1]["headers"]["X-Api-Key"] == "test-key" + + +def test_verifier_refreshes_jwks_once_for_unknown_kid( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = _callback_body() + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="new-kid", + body=body, + issuer_url="https://smith.example", + ) + stale_key = Ed25519PrivateKey.generate() + calls = _install_fake_langsmith_get( + monkeypatch, + stale_key, + tenant_id="tenant-1", + api_url="https://smith.example", + jwks_sequence=[ + _jwks(stale_key, kid="old-kid"), + _jwks(key, kid="new-kid"), + ], + ) + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + claims = verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + body, + expected_sandbox_id="sandbox-1", + ) + + assert claims.sandbox_id == "sandbox-1" + assert [call["url"] for call in calls] == [ + "https://smith.example/.well-known/jwks.json", + "https://smith.example/.well-known/jwks.json", + ] + + +def test_verifier_accepts_raw_asgi_headers_and_default_clock_leeway( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = _callback_body() + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + extra_claims={"exp": int(time.time()) - 30}, + ) + _install_fake_langsmith_get(monkeypatch, key, tenant_id="tenant-1") + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + claims = verifier.verify( + [(SANDBOX_CALLBACK_SIGNATURE_HEADER.lower().encode(), token.encode())], + body, + expected_sandbox_id="sandbox-1", + ) + + assert claims.sandbox_id == "sandbox-1" + + +@pytest.mark.parametrize( + ("mutate_claims", "message"), + [ + ({"iss": "https://other.example"}, "issuer"), + ({"sub": "other-subject"}, "subject"), + ({"aud": ""}, "aud"), + ({"exp": 1}, "expired"), + ({"nbf": time.time() + 3600}, "not yet valid"), + ], +) +def test_verifier_rejects_bad_claims( + monkeypatch: pytest.MonkeyPatch, + mutate_claims: Mapping[str, Any], + message: str, +) -> None: + body = _callback_body() + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + extra_claims=mutate_claims, + ) + _install_fake_langsmith_get(monkeypatch, key, tenant_id="tenant-1") + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + expected_organization_id="org-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match=message): + verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + body, + expected_sandbox_id="sandbox-1", + ) + + +@pytest.mark.parametrize( + ("body", "message"), + [ + (_callback_body(tenant_id="other-tenant"), "tenant_id"), + (_callback_body(organization_id="other-org"), "organization_id"), + (_callback_body(sandbox_id="other-sandbox"), "sandbox_id"), + (b'{"host":"example.com","port":80}', "identity"), + ], +) +def test_verifier_rejects_bad_body_identity( + monkeypatch: pytest.MonkeyPatch, + body: bytes, + message: str, +) -> None: + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + ) + _install_fake_langsmith_get(monkeypatch, key, tenant_id="tenant-1") + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + expected_organization_id="org-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match=message): + verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + body, + expected_sandbox_id="sandbox-1", + ) + + +def test_verifier_rejects_body_hash_mismatch(monkeypatch: pytest.MonkeyPatch) -> None: + signed_body = _callback_body(port=80) + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=signed_body, + issuer_url="https://smith.example", + ) + _install_fake_langsmith_get(monkeypatch, key, tenant_id="tenant-1") + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match="body hash"): + verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + _callback_body(port=443), + expected_sandbox_id="sandbox-1", + ) + + +def test_verifier_rejects_wrong_public_key(monkeypatch: pytest.MonkeyPatch) -> None: + body = _callback_body() + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + ) + _install_fake_langsmith_get( + monkeypatch, + Ed25519PrivateKey.generate(), + tenant_id="tenant-1", + jwks_kid="callback-kid", + ) + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match="signature"): + verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + body, + expected_sandbox_id="sandbox-1", + ) + + +def test_verifier_requires_signature_header() -> None: + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match="Missing"): + verifier.verify({}, b"{}", expected_sandbox_id="sandbox-1") + + +def test_verifier_requires_sandbox_id_unless_unsafe() -> None: + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + with pytest.raises(SandboxCallbackVerificationError, match="expected_sandbox_id"): + verifier.verify({SANDBOX_CALLBACK_SIGNATURE_HEADER: "token"}, b"{}") + with pytest.raises(SandboxCallbackVerificationError, match="not both"): + verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: "token"}, + b"{}", + expected_sandbox_id="sandbox-1", + unsafely_allow_any_sandbox_id=True, + ) + + +def test_verifier_can_unsafely_allow_any_sandbox_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + body = _callback_body(sandbox_id="sandbox-from-body") + key = Ed25519PrivateKey.generate() + token = _sign_callback_jwt( + key, + kid="callback-kid", + body=body, + issuer_url="https://smith.example", + ) + _install_fake_langsmith_get(monkeypatch, key, tenant_id="tenant-1") + verifier = SandboxCallbackVerifier( + api_url="https://smith.example", + expected_tenant_id="tenant-1", + ) + + claims = verifier.verify( + {SANDBOX_CALLBACK_SIGNATURE_HEADER: token}, + body, + unsafely_allow_any_sandbox_id=True, + ) + + assert claims.sandbox_id == "sandbox-from-body" + + +def _install_fake_langsmith_get( + monkeypatch: pytest.MonkeyPatch, + key: Ed25519PrivateKey, + *, + tenant_id: str, + api_url: str = "https://smith.example", + jwks_kid: str = "callback-kid", + jwks_sequence: Optional[list[dict[str, list[dict[str, str]]]]] = None, +) -> list[dict[str, Any]]: + calls: list[dict[str, Any]] = [] + jwks_responses = list(jwks_sequence or [_jwks(key, kid=jwks_kid)]) + + def fake_get(url: str, **kwargs: Any) -> httpx.Response: + calls.append({"url": url, **kwargs}) + request = httpx.Request("GET", url) + if url == "https://smith.example/.well-known/jwks.json": + jwks = jwks_responses.pop(0) if jwks_responses else _jwks(key, kid=jwks_kid) + return httpx.Response(200, json=jwks, request=request) + if url == api_url.rstrip("/") + "/sessions": + return httpx.Response(200, json=[{"tenant_id": tenant_id}], request=request) + return httpx.Response(404, json={"error": "not found"}, request=request) + + monkeypatch.setattr(httpx, "get", fake_get) + return calls + + +def _sign_callback_jwt( + key: Ed25519PrivateKey, + *, + kid: str, + body: bytes, + issuer_url: str, + audience: str = "https://customer.example/callback", + extra_claims: Optional[Mapping[str, Any]] = None, +) -> str: + now = int(time.time()) + claims: dict[str, Any] = { + "iss": issuer_url.rstrip("/"), + "sub": SANDBOX_CALLBACK_SUBJECT, + "aud": audience, + "iat": now, + "nbf": now, + "exp": now + 300, + "jti": "jti-1", + "body_sha256": hashlib.sha256(body).hexdigest(), + } + if extra_claims: + claims.update(extra_claims) + header = {"alg": "EdDSA", "typ": "JWT", "kid": kid} + signing_input = b".".join( + [ + _b64url_json(header), + _b64url_json(claims), + ] + ) + return signing_input.decode() + "." + _b64url(key.sign(signing_input)).decode() + + +def _jwks(key: Ed25519PrivateKey, *, kid: str) -> dict[str, list[dict[str, str]]]: + public_bytes = key.public_key().public_bytes( + encoding=Encoding.Raw, + format=PublicFormat.Raw, + ) + return { + "keys": [ + { + "kty": "OKP", + "crv": "Ed25519", + "kid": kid, + "use": "sig", + "alg": "EdDSA", + "x": _b64url(public_bytes).decode(), + } + ] + } + + +def _b64url_json(value: Mapping[str, Any]) -> bytes: + return _b64url(json.dumps(value, separators=(",", ":")).encode()) + + +def _b64url(value: bytes) -> bytes: + return base64.urlsafe_b64encode(value).rstrip(b"=")