diff --git a/docs/executors.md b/docs/executors.md index e8ff816..3b91a23 100644 --- a/docs/executors.md +++ b/docs/executors.md @@ -193,6 +193,58 @@ config = ContainerConfig( - **Resource limits** - CPU, memory, disk quotas - **Clean state** - Each execution in fresh container +### Remote Session Server Mode + +`ContainerExecutor` can also connect to an existing session server instead of starting a +local container itself: + +```python +from py_code_mode import RedisStorage, Session +from py_code_mode.execution import ContainerExecutor + +storage = RedisStorage( + url="redis://localhost:6379", + prefix="production", + workspace_id="workspace-123", +) + +executor = ContainerExecutor(remote_url="http://session-server:8000") + +async with Session(storage=storage, executor=executor) as session: + result = await session.run(agent_code) +``` + +In remote mode: + +- the host storage backend supplies `workspace_id` +- the server issues the execution `session_id` +- workflows, artifacts, and workflow search are scoped to that workspace + +The executor binds the session by calling `POST /sessions` and then sends the returned +session ID on subsequent execution, workflow, artifact, and info requests via +`X-Session-ID`. + +Multiple sessions using the same `workspace_id` share storage state. Different +`workspace_id` values are isolated from each other. + +If `workspace_id` is omitted, the remote server uses the legacy default namespace for +backward compatibility. This is one shared unscoped namespace, not access to all +workspaces. + +### Remote Storage Requirements + +Remote mode only sends workspace identity. The session server must be configured with +server-owned storage roots so it can rebuild workspace-scoped storage internally. 
+ +Relevant server config fields: + +- `storage_base_path`: base directory for file-backed workspace storage +- `storage_prefix`: Redis prefix for Redis-backed workspace storage + +The host storage and the remote server must refer to the same logical backing store. +For true remote deployments, Redis-backed storage is recommended because both sides can +share the same namespace cleanly. + ### Configuration Options ```python diff --git a/docs/production.md b/docs/production.md index bd717de..55b8594 100644 --- a/docs/production.md +++ b/docs/production.md @@ -94,16 +94,21 @@ else: ### 5. Isolate Storage by Tenant -Use separate Redis prefixes for multi-tenant deployments: +Use a stable environment prefix plus `workspace_id` for multi-tenant deployments: ```python def get_storage(tenant_id: str, redis_url: str) -> RedisStorage: return RedisStorage( url=redis_url, - prefix=f"tenant-{tenant_id}" + prefix="production", + workspace_id=tenant_id, ) ``` +If `workspace_id` is omitted, the system uses the legacy default namespace. That is one +shared unscoped namespace, so multi-tenant deployments should set `workspace_id` +explicitly. + --- ## Scalability Patterns @@ -140,6 +145,23 @@ async def handle_request(agent_code: str, tenant_id: str): Load balancer distributes requests across instances. +### Remote Session Servers + +For remote `ContainerExecutor(remote_url=...)` deployments: + +- the client provides `workspace_id` through the storage backend +- the session server creates an execution `session_id` +- workflow/artifact isolation is enforced by the server's workspace-scoped storage bundle + +Configure the session server with server-owned storage roots: + +- `storage_base_path` for file-backed storage +- `storage_prefix` for Redis-backed storage + +The host storage configuration and the remote session server must refer to the same +logical backing store. 
In practice, Redis-backed storage is the recommended production +topology for remote deployments because both sides can share one namespace directly. + --- ## Container Image Management diff --git a/docs/storage.md b/docs/storage.md index 307a363..c1c6598 100644 --- a/docs/storage.md +++ b/docs/storage.md @@ -81,6 +81,29 @@ RedisStorage( ) ``` +### Workspace Scoping + +Both storage backends accept an optional `workspace_id`: + +```python +from pathlib import Path + +from py_code_mode import FileStorage, RedisStorage + +file_storage = FileStorage(base_path=Path("./data"), workspace_id="client-a") +redis_storage = RedisStorage( + url="redis://localhost:6379", + prefix="production", + workspace_id="client-a", +) +``` + +When `workspace_id` is set, workflows, artifacts, and vector caches are scoped to that +workspace and shared by other sessions using the same ID. + +When `workspace_id` is omitted, storage uses the legacy unscoped namespace. This is one +shared default namespace, **not** access to all workspaces. 
+ --- ## One Agent Learns, All Agents Benefit @@ -125,6 +148,17 @@ tenant_a_storage = RedisStorage(url="redis://localhost:6379", prefix="tenant-a") tenant_b_storage = RedisStorage(url="redis://localhost:6379", prefix="tenant-b") ``` +For multi-tenant systems inside one environment, prefer a stable app-level prefix plus +per-session `workspace_id` values: + +```python +storage = RedisStorage( + url="redis://localhost:6379", + prefix="production", + workspace_id="client-a", +) +``` + --- ## Migrating Between Storage Backends @@ -199,10 +233,24 @@ python -m py_code_mode.store diff \ ### Multi-Tenant -- Use separate prefix per tenant +- Use a stable environment prefix (for example `prod`) plus `workspace_id` per tenant or campaign - Consider separate Redis instances for hard isolation - Monitor Redis memory usage +### Remote Session Servers + +When using `ContainerExecutor(remote_url=...)`, the host storage object and the remote +session server must point at the same logical backing store: + +- file-backed remote mode: host `FileStorage(...)` should correspond to the server's + `storage_base_path` +- Redis-backed remote mode: host `RedisStorage(prefix=..., workspace_id=...)` should + correspond to the server's `storage_prefix` + +For true remote deployments, `RedisStorage` is usually the simplest and safest option +because both the host process and the remote session server can share the same Redis +namespace directly. + ### Workflow Lifecycle ```python diff --git a/src/py_code_mode/bootstrap.py b/src/py_code_mode/bootstrap.py index b122206..7016797 100644 --- a/src/py_code_mode/bootstrap.py +++ b/src/py_code_mode/bootstrap.py @@ -52,8 +52,10 @@ async def bootstrap_namespaces(config: dict[str, Any]) -> NamespaceBundle: Args: config: Dict with "type" key ("file" or "redis") and type-specific fields. 
- - For "file": {"type": "file", "base_path": str, "tools_path": str|None} + - For "file": {"type": "file", "base_path": str, "workspace_id": str|None, + "tools_path": str|None} - For "redis": {"type": "redis", "url": str, "prefix": str, + "workspace_id": str|None, "tools_path": str|None} - tools_path is optional; if provided, tools load from that directory @@ -128,7 +130,8 @@ async def _bootstrap_file_storage(config: dict[str, Any]) -> NamespaceBundle: from py_code_mode.storage import FileStorage base_path = Path(config["base_path"]) - storage = FileStorage(base_path) + workspace_id = config.get("workspace_id") + storage = FileStorage(base_path, workspace_id=workspace_id) tools_ns = await _load_tools_namespace(config.get("tools_path")) artifact_store = storage.get_artifact_store() @@ -159,15 +162,16 @@ async def _bootstrap_redis_storage(config: dict[str, Any]) -> NamespaceBundle: url = config["url"] prefix = config["prefix"] + workspace_id = config.get("workspace_id") # Connect to Redis - storage = RedisStorage(url=url, prefix=prefix) + storage = RedisStorage(url=url, prefix=prefix, workspace_id=workspace_id) tools_ns = await _load_tools_namespace(config.get("tools_path")) artifact_store = storage.get_artifact_store() # Create deps namespace - deps_store = RedisDepsStore(storage.client, prefix=f"{prefix}:deps") + deps_store = RedisDepsStore(storage.client, prefix=prefix) installer = PackageInstaller() deps_ns = DepsNamespace(deps_store, installer) diff --git a/src/py_code_mode/execution/container/client.py b/src/py_code_mode/execution/container/client.py index 2092e5a..c1a662d 100644 --- a/src/py_code_mode/execution/container/client.py +++ b/src/py_code_mode/execution/container/client.py @@ -3,7 +3,7 @@ This client connects to a running session server and provides a Python API for code execution. Each client maintains its own isolated session with separate Python namespace and artifacts. -The server allocates the session on the first execute call. 
+The client lazily binds a server session on first session-scoped use. Usage: async with SessionClient("http://localhost:8080") as client: @@ -14,6 +14,7 @@ from __future__ import annotations +import asyncio from dataclasses import dataclass from typing import Any @@ -75,8 +76,8 @@ class SessionClient: - Separate artifact directory Use the same client instance across requests to maintain state. - The server issues the session ID on first execution and the client - reuses it for later session-scoped requests. + The client lazily creates a bound remote session via POST /sessions and + reuses that session ID for later session-scoped requests. """ def __init__( @@ -102,6 +103,9 @@ def __init__( self.session_id: str | None = None self.auth_token = auth_token self._client: httpx.AsyncClient | None = None + self._workspace_id: str | None = None + self._auto_init_session = True + self._session_lock = asyncio.Lock() async def _get_client(self) -> httpx.AsyncClient: """Get or create HTTP client.""" @@ -109,15 +113,101 @@ async def _get_client(self) -> httpx.AsyncClient: self._client = httpx.AsyncClient(timeout=self.timeout) return self._client + def _auth_headers(self) -> dict[str, str]: + """Get headers with optional auth token only.""" + headers: dict[str, str] = {} + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + return headers + def _headers(self) -> dict[str, str]: """Get headers with session ID and optional auth token.""" - headers: dict[str, str] = {} + headers = self._auth_headers() if self.session_id is not None: headers["X-Session-ID"] = self.session_id - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" return headers + async def _ensure_session_bound(self) -> None: + """Ensure a previously initialized session is rebound after reset/expiry.""" + if self.session_id is None and self._auto_init_session: + await self._create_session(self._workspace_id) + + def _is_invalid_session_response(self, response: 
httpx.Response) -> bool: + """Check whether a response indicates a stale or missing bound session.""" + if response.status_code != 400: + return False + + try: + data = response.json() + except ValueError: + return False + return data.get("detail") == "Invalid session ID" + + async def _create_session(self, workspace_id: str | None = None) -> str: + """Create and bind a remote session without taking the session lock.""" + client = await self._get_client() + payload: dict[str, str] = {} + if workspace_id is not None: + payload["workspace_id"] = workspace_id + + response = await client.post( + f"{self.base_url}/sessions", + json=payload, + headers=self._auth_headers(), + ) + response.raise_for_status() + data = response.json() + + self.session_id = data["session_id"] + return self.session_id + + async def _rebind_if_current(self, failed_session_id: str) -> bool: + """Rebind if the current client session still matches the failed session.""" + if not self._auto_init_session: + return False + if self.session_id is not None and self.session_id != failed_session_id: + return False + + self.session_id = None + await self._create_session(self._workspace_id) + return True + + async def _send_session_request( + self, + method: str, + path: str, + **kwargs: Any, + ) -> httpx.Response: + """Send a session-scoped request, rebinding once if the server expired the session.""" + async with self._session_lock: + await self._ensure_session_bound() + client = await self._get_client() + request_method = getattr(client, method.lower()) + request_session_id = self.session_id + response = await request_method( + f"{self.base_url}{path}", + headers=self._headers(), + **kwargs, + ) + + if request_session_id is not None and self._is_invalid_session_response(response): + rebound = await self._rebind_if_current(request_session_id) + if rebound: + response = await request_method( + f"{self.base_url}{path}", + headers=self._headers(), + **kwargs, + ) + + return response + + async def 
init_session(self, workspace_id: str | None = None) -> str: + """Create and bind an explicit remote session.""" + async with self._session_lock: + self._workspace_id = workspace_id + self._auto_init_session = True + return await self._create_session(workspace_id) + async def execute( self, code: str, @@ -132,28 +222,24 @@ async def execute( Returns: ExecuteResult with value, stdout, error. """ - client = await self._get_client() payload = {"code": code} if timeout is not None: payload["timeout"] = timeout # type: ignore - response = await client.post( - f"{self.base_url}/execute", + response = await self._send_session_request( + "POST", + "/execute", json=payload, - headers=self._headers(), ) response.raise_for_status() data = response.json() - session_id = data["session_id"] - self.session_id = session_id - return ExecuteResult( value=data["value"], stdout=data["stdout"], error=data["error"], execution_time_ms=data["execution_time_ms"], - session_id=session_id, + session_id=data["session_id"], ) async def health(self) -> HealthResult: @@ -178,8 +264,7 @@ async def info(self) -> InfoResult: Returns: InfoResult with available tools and workflows. """ - client = await self._get_client() - response = await client.get(f"{self.base_url}/info", headers=self._headers()) + response = await self._send_session_request("GET", "/info") response.raise_for_status() data = response.json() @@ -197,22 +282,23 @@ async def reset(self) -> ResetResult: Returns: ResetResult confirming reset. 
""" - if self.session_id is None: - return ResetResult(status="reset", session_id=None) - - client = await self._get_client() - response = await client.post( - f"{self.base_url}/reset", - headers=self._headers(), - ) - response.raise_for_status() - data = response.json() - self.session_id = None + async with self._session_lock: + if self.session_id is None: + return ResetResult(status="reset", session_id=None) + + client = await self._get_client() + response = await client.post( + f"{self.base_url}/reset", + headers=self._headers(), + ) + response.raise_for_status() + data = response.json() + self.session_id = None - return ResetResult( - status=data["status"], - session_id=data.get("session_id"), - ) + return ResetResult( + status=data["status"], + session_id=data.get("session_id"), + ) async def install_deps(self, packages: list[str]) -> dict[str, Any]: """Install packages in the container. @@ -301,42 +387,32 @@ async def search_tools(self, query: str, limit: int = 10) -> list[dict[str, Any] async def list_workflows(self) -> list[dict[str, Any]]: """List all workflows.""" - client = await self._get_client() - response = await client.get( - f"{self.base_url}/api/workflows", - headers=self._headers(), - ) + response = await self._send_session_request("GET", "/api/workflows") response.raise_for_status() return response.json() async def search_workflows(self, query: str, limit: int = 5) -> list[dict[str, Any]]: """Search workflows.""" - client = await self._get_client() - response = await client.get( - f"{self.base_url}/api/workflows/search", + response = await self._send_session_request( + "GET", + "/api/workflows/search", params={"query": query, "limit": limit}, - headers=self._headers(), ) response.raise_for_status() return response.json() async def get_workflow(self, name: str) -> dict[str, Any] | None: """Get workflow by name with full source.""" - client = await self._get_client() - response = await client.get( - f"{self.base_url}/api/workflows/{name}", - 
headers=self._headers(), - ) + response = await self._send_session_request("GET", f"/api/workflows/{name}") response.raise_for_status() return response.json() async def create_workflow(self, name: str, source: str, description: str) -> dict[str, Any]: """Create a new workflow.""" - client = await self._get_client() - response = await client.post( - f"{self.base_url}/api/workflows", + response = await self._send_session_request( + "POST", + "/api/workflows", json={"name": name, "source": source, "description": description}, - headers=self._headers(), ) if response.status_code != 200: data = response.json() @@ -345,11 +421,7 @@ async def create_workflow(self, name: str, source: str, description: str) -> dic async def delete_workflow(self, name: str) -> bool: """Delete a workflow.""" - client = await self._get_client() - response = await client.delete( - f"{self.base_url}/api/workflows/{name}", - headers=self._headers(), - ) + response = await self._send_session_request("DELETE", f"/api/workflows/{name}") response.raise_for_status() return response.json() @@ -363,11 +435,7 @@ async def list_artifacts(self) -> list[dict[str, Any]]: Returns: List of artifact metadata dicts. """ - client = await self._get_client() - response = await client.get( - f"{self.base_url}/api/artifacts", - headers=self._headers(), - ) + response = await self._send_session_request("GET", "/api/artifacts") response.raise_for_status() return response.json() @@ -383,11 +451,7 @@ async def load_artifact(self, name: str) -> Any: Raises: RuntimeError: If artifact not found. """ - client = await self._get_client() - response = await client.get( - f"{self.base_url}/api/artifacts/{name}", - headers=self._headers(), - ) + response = await self._send_session_request("GET", f"/api/artifacts/{name}") if response.status_code == 404: raise RuntimeError(f"Artifact '{name}' not found") response.raise_for_status() @@ -411,16 +475,15 @@ async def save_artifact( Returns: Artifact metadata dict. 
""" - client = await self._get_client() - response = await client.post( - f"{self.base_url}/api/artifacts", + response = await self._send_session_request( + "POST", + "/api/artifacts", json={ "name": name, "data": data, "description": description, "metadata": metadata, }, - headers=self._headers(), ) response.raise_for_status() return response.json() @@ -434,11 +497,7 @@ async def delete_artifact(self, name: str) -> None: Raises: RuntimeError: If artifact not found. """ - client = await self._get_client() - response = await client.delete( - f"{self.base_url}/api/artifacts/{name}", - headers=self._headers(), - ) + response = await self._send_session_request("DELETE", f"/api/artifacts/{name}") if response.status_code == 404: raise RuntimeError(f"Artifact '{name}' not found") response.raise_for_status() diff --git a/src/py_code_mode/execution/container/config.py b/src/py_code_mode/execution/container/config.py index 2e73fbf..529ba96 100644 --- a/src/py_code_mode/execution/container/config.py +++ b/src/py_code_mode/execution/container/config.py @@ -44,6 +44,8 @@ class SessionConfig: artifacts_path: Path = field(default_factory=lambda: Path("/workspace/artifacts")) artifact_backend: str = "file" # "file" or "redis" redis_url: str | None = None + storage_base_path: Path | None = None + storage_prefix: str | None = None # Execution default_timeout: float = 30.0 @@ -83,12 +85,16 @@ def from_env(cls) -> SessionConfig: config.workflows_path = Path(workflows_path) if artifacts_path := os.environ.get("ARTIFACTS_PATH"): config.artifacts_path = Path(artifacts_path) + if storage_base_path := os.environ.get("STORAGE_BASE_PATH"): + config.storage_base_path = Path(storage_base_path) # Artifact backend if backend := os.environ.get("ARTIFACT_BACKEND"): config.artifact_backend = backend if redis_url := os.environ.get("REDIS_URL"): config.redis_url = redis_url + if storage_prefix := os.environ.get("STORAGE_PREFIX"): + config.storage_prefix = storage_prefix # Timeouts if timeout := 
os.environ.get("DEFAULT_TIMEOUT"): @@ -147,6 +153,12 @@ def _from_dict(cls, data: dict[str, Any]) -> SessionConfig: artifacts_path=Path(data.get("artifacts_path", "/workspace/artifacts")), artifact_backend=data.get("artifact_backend", "file"), redis_url=data.get("redis_url"), + storage_base_path=( + Path(data["storage_base_path"]) + if data.get("storage_base_path") is not None + else None + ), + storage_prefix=data.get("storage_prefix"), default_timeout=data.get("default_timeout", 30.0), max_execution_time=data.get("max_execution_time", 300.0), host=data.get("host", "0.0.0.0"), diff --git a/src/py_code_mode/execution/container/executor.py b/src/py_code_mode/execution/container/executor.py index ec16903..a24dc87 100644 --- a/src/py_code_mode/execution/container/executor.py +++ b/src/py_code_mode/execution/container/executor.py @@ -375,6 +375,8 @@ async def start( timeout=self.config.timeout, auth_token=self.config.auth_token, ) + workspace_id = storage.workspace_id if storage is not None else None + await self._client.init_session(workspace_id=workspace_id) return # Initialize Docker client with fallback socket detection @@ -404,8 +406,9 @@ async def start( workflows_path = access.workflows_path artifacts_path = access.artifacts_path deps_path = None - if artifacts_path is not None: - deps_path = artifacts_path.parent / "deps" + deps_root = access.root_path or artifacts_path.parent + if deps_root is not None: + deps_path = deps_root / "deps" # Create directories on host before mounting # Workflows need to exist for volume mount if workflows_path: @@ -425,7 +428,7 @@ async def start( tools_prefix = None # Tools owned by executor workflows_prefix = access.workflows_prefix artifacts_prefix = access.artifacts_prefix - deps_prefix = None # Deps owned by executor + deps_prefix = access.root_prefix or access.workflows_prefix.rsplit(":", 1)[0] else: raise TypeError( f"Unexpected storage access type: {type(access).__name__}. 
" diff --git a/src/py_code_mode/execution/container/server.py b/src/py_code_mode/execution/container/server.py index 435b122..0217a02 100644 --- a/src/py_code_mode/execution/container/server.py +++ b/src/py_code_mode/execution/container/server.py @@ -64,6 +64,7 @@ from py_code_mode.execution.in_process import ( # noqa: E402 InProcessExecutor as CodeExecutor, ) +from py_code_mode.storage.backends import _validate_workspace_id # noqa: E402 from py_code_mode.tools import ToolRegistry # noqa: E402 from py_code_mode.workflows import ( # noqa: E402 FileWorkflowStore, @@ -133,6 +134,16 @@ class ResetResponseModel(BaseModel): # type: ignore status: str session_id: str + class CreateSessionRequestModel(BaseModel): # type: ignore + """Request to create a bound remote session.""" + + workspace_id: str | None = None + + class SessionResponseModel(BaseModel): # type: ignore + """Response containing a created session ID.""" + + session_id: str + # NOTE: SessionInfoModel removed - /sessions endpoint was removed for security # (session enumeration attack vector) @@ -212,13 +223,26 @@ class DepsSyncResult(BaseModel): # type: ignore failed: list[str] = [] +@dataclass +class WorkspaceBundle: + """Shared workflow/artifact state for one workspace scope.""" + + workspace_id: str | None + workflow_library: WorkflowLibrary | None + artifact_store: ArtifactStoreProtocol + artifacts_path: str + + @dataclass class Session: - """Individual session state.""" + """Individual session state with isolated Python state and bound storage.""" session_id: str executor: CodeExecutor + workflow_library: WorkflowLibrary | None artifact_store: ArtifactStoreProtocol + workspace_id: str | None = None + artifacts_path: str = "" created_at: float = field(default_factory=time.time) last_used: float = field(default_factory=time.time) execution_count: int = 0 @@ -237,6 +261,11 @@ class ServerState: sessions: dict[str, Session] = field(default_factory=dict) start_time: float = 0.0 redis_mode: bool = False + 
default_bundle: WorkspaceBundle | None = None + workspace_bundles: dict[str, WorkspaceBundle] = field(default_factory=dict) + redis_client: Any | None = None + redis_workflows_prefix: str | None = None + redis_artifacts_prefix: str | None = None # Global state @@ -277,22 +306,136 @@ def build_workflow_library(config: SessionConfig) -> WorkflowLibrary | None: return create_workflow_library(store=store) -def create_session(session_id: str) -> Session: - """Create a new isolated session.""" - if _state.config is None: +def build_workflow_library_for_path(workflows_path: Path) -> WorkflowLibrary | None: + """Build a file-backed workflow library for the given path.""" + try: + workflows_path.mkdir(parents=True, exist_ok=True) + except OSError as e: + logger.warning("Cannot create workflows directory at %s: %s", workflows_path, e) + return None + + store = FileWorkflowStore(workflows_path) + return create_workflow_library(store=store) + + +def build_workflow_library_for_redis(redis_client: Any, prefix: str) -> WorkflowLibrary: + """Build a Redis-backed workflow library for the given prefix.""" + from py_code_mode.workflows import RedisWorkflowStore + + store = RedisWorkflowStore(redis_client, prefix=prefix) + return create_workflow_library(store=store) + + +def get_default_bundle() -> WorkspaceBundle: + """Return the legacy unscoped bundle.""" + if _state.default_bundle is None: + raise RuntimeError("Default workspace bundle not initialized") + return _state.default_bundle + + +def _derive_redis_root_prefix( + workflows_prefix: str | None, + artifacts_prefix: str | None, +) -> str | None: + """Derive a shared Redis root prefix from legacy workflow/artifact prefixes.""" + if workflows_prefix is None or artifacts_prefix is None: + return None + if not workflows_prefix.endswith(":workflows") or not artifacts_prefix.endswith(":artifacts"): + return None + + workflows_root = workflows_prefix[: -len(":workflows")] + artifacts_root = artifacts_prefix[: -len(":artifacts")] + if 
workflows_root != artifacts_root: + return None + if ":ws:" in workflows_root: + return None + return workflows_root + + +def build_workspace_bundle(workspace_id: str) -> WorkspaceBundle: + """Build a cached workflow/artifact bundle for one workspace.""" + workspace_id = _validate_workspace_id(workspace_id) + config = _state.config + if config is None: raise RuntimeError("Server not initialized") - # Use shared artifact store (already initialized at startup for both modes) - if _state.artifact_store is None: - raise RuntimeError("Artifact store not initialized") - artifact_store = _state.artifact_store + if _state.redis_mode: + if _state.redis_client is None: + raise RuntimeError("Redis client not initialized") + + storage_prefix = config.storage_prefix or _derive_redis_root_prefix( + _state.redis_workflows_prefix, + _state.redis_artifacts_prefix, + ) + if storage_prefix is None: + raise RuntimeError("Workspace-scoped Redis storage requires an explicit storage_prefix") + + workflow_prefix = f"{storage_prefix}:ws:{workspace_id}:workflows" + artifact_prefix = f"{storage_prefix}:ws:{workspace_id}:artifacts" + workflow_library = build_workflow_library_for_redis(_state.redis_client, workflow_prefix) + + from py_code_mode.artifacts import RedisArtifactStore + + artifact_store = RedisArtifactStore(_state.redis_client, prefix=artifact_prefix) + return WorkspaceBundle( + workspace_id=workspace_id, + workflow_library=workflow_library, + artifact_store=artifact_store, + artifacts_path=artifact_prefix, + ) + + storage_base_path = config.storage_base_path + if storage_base_path is not None: + workspace_root = storage_base_path / "workspaces" / workspace_id + workflows_path = workspace_root / "workflows" + artifacts_path = workspace_root / "artifacts" + else: + workflows_path = ( + config.workflows_path.parent / "workspaces" / workspace_id / config.workflows_path.name + ) + artifacts_path = ( + config.artifacts_path.parent / "workspaces" / workspace_id / 
config.artifacts_path.name + ) + + workflow_library = build_workflow_library_for_path(workflows_path) + artifacts_path.mkdir(parents=True, exist_ok=True) + artifact_store = FileArtifactStore(artifacts_path) + return WorkspaceBundle( + workspace_id=workspace_id, + workflow_library=workflow_library, + artifact_store=artifact_store, + artifacts_path=str(artifacts_path), + ) + + +def get_workspace_bundle(workspace_id: str | None) -> WorkspaceBundle: + """Return the default or scoped bundle for a request/session.""" + if workspace_id is None: + return get_default_bundle() + + bundle = _state.workspace_bundles.get(workspace_id) + if bundle is not None: + return bundle + + bundle = build_workspace_bundle(workspace_id) + _state.workspace_bundles[workspace_id] = bundle + return bundle + + +def create_session(session_id: str, workspace_id: str | None = None) -> Session: + """Create a new isolated Python session bound to a storage bundle.""" + config = _state.config + if config is None: + raise RuntimeError("Server not initialized") + + bundle = get_workspace_bundle(workspace_id) # Create deps namespace if deps_store is available deps_namespace = None if _state.deps_store is not None and _state.deps_installer is not None: base_deps = DepsNamespace(_state.deps_store, _state.deps_installer) # Wrap if runtime deps disabled - if not _state.config.allow_runtime_deps: + if not config.allow_runtime_deps: deps_namespace = ControlledDepsNamespace(base_deps, allow_runtime=False) else: deps_namespace = base_deps @@ -300,23 +443,26 @@ def create_session(session_id: str) -> Session: # Create executor with shared registries but isolated namespace/artifacts executor = CodeExecutor( registry=_state.registry, - workflow_library=_state.workflow_library, - artifact_store=artifact_store, + workflow_library=bundle.workflow_library, + artifact_store=bundle.artifact_store, deps_namespace=deps_namespace, - default_timeout=_state.config.default_timeout, + default_timeout=config.default_timeout, ) 
return Session( session_id=session_id, executor=executor, - artifact_store=artifact_store, + workflow_library=bundle.workflow_library, + artifact_store=bundle.artifact_store, + workspace_id=workspace_id, + artifacts_path=bundle.artifacts_path, ) -def create_new_session() -> Session: +def create_new_session(workspace_id: str | None = None) -> Session: """Create a new isolated session with a server-issued ID.""" session_id = str(uuid.uuid4()) - session = create_session(session_id) + session = create_session(session_id, workspace_id=workspace_id) _state.sessions[session_id] = session return session @@ -340,6 +486,29 @@ def cleanup_expired_sessions() -> int: return len(expired) +def get_bound_session(session_id: str | None) -> Session: + """Resolve a request's bound session or raise 400.""" + cleanup_expired_sessions() + if session_id is None: + raise HTTPException(status_code=400, detail="Invalid session ID") + + session = get_existing_session(session_id) + if session is None: + raise HTTPException(status_code=400, detail="Invalid session ID") + return session + + +def get_request_bundle(session_id: str | None) -> WorkspaceBundle: + """Resolve the effective workflow/artifact bundle for a request.""" + session = get_bound_session(session_id) + return WorkspaceBundle( + workspace_id=session.workspace_id, + workflow_library=session.workflow_library, + artifact_store=session.artifact_store, + artifacts_path=session.artifacts_path, + ) + + def install_python_deps(deps: list[str]) -> None: """Install Python dependencies if not already installed. 
@@ -375,7 +544,7 @@ async def initialize_server(config: SessionConfig) -> None: if config.python_deps: install_python_deps(config.python_deps) - redis_url = os.environ.get("REDIS_URL") + redis_url = config.redis_url or os.environ.get("REDIS_URL") if redis_url: # Redis mode: load everything from Redis with semantic search @@ -390,8 +559,14 @@ async def initialize_server(config: SessionConfig) -> None: # Get prefixes from environment (set by ContainerExecutor), with defaults tools_prefix = os.environ.get("REDIS_TOOLS_PREFIX", "tools") - workflows_prefix = os.environ.get("REDIS_WORKFLOWS_PREFIX", "workflows") - artifacts_prefix = os.environ.get("REDIS_ARTIFACTS_PREFIX", "artifacts") + workflows_prefix = os.environ.get( + "REDIS_WORKFLOWS_PREFIX", + f"{config.storage_prefix}:workflows" if config.storage_prefix else "workflows", + ) + artifacts_prefix = os.environ.get( + "REDIS_ARTIFACTS_PREFIX", + f"{config.storage_prefix}:artifacts" if config.storage_prefix else "artifacts", + ) # Tools from Redis tool_store = RedisToolStore(r, prefix=tools_prefix) @@ -408,9 +583,8 @@ async def initialize_server(config: SessionConfig) -> None: artifact_store = RedisArtifactStore(r, prefix=artifacts_prefix) # Deps from Redis - # Derive deps prefix from tools prefix namespace (e.g., "myapp:tools" -> "myapp:deps") - # If tools_prefix has no namespace separator, uses tools_prefix directly as base - deps_prefix = os.environ.get("REDIS_DEPS_PREFIX", f"{tools_prefix.rsplit(':', 1)[0]}:deps") + # Deps use the root Redis prefix. RedisDepsStore appends ":deps" internally. 
+ deps_prefix = os.environ.get("REDIS_DEPS_PREFIX", tools_prefix.rsplit(":", 1)[0]) deps_store = RedisDepsStore(r, prefix=deps_prefix) deps_installer = PackageInstaller() logger.info(" Deps in Redis (%s): initialized", deps_prefix) @@ -424,6 +598,13 @@ async def initialize_server(config: SessionConfig) -> None: deps_store.add(dep) logger.info(" Pre-configured deps: %s", container_deps) + default_bundle = WorkspaceBundle( + workspace_id=None, + workflow_library=workflow_library, + artifact_store=artifact_store, + artifacts_path=artifacts_prefix, + ) + _state = ServerState( config=config, registry=registry, @@ -434,6 +615,11 @@ async def initialize_server(config: SessionConfig) -> None: sessions={}, start_time=time.time(), redis_mode=True, + default_bundle=default_bundle, + workspace_bundles={}, + redis_client=r, + redis_workflows_prefix=workflows_prefix, + redis_artifacts_prefix=artifacts_prefix, ) else: # File mode: load from config paths @@ -484,6 +670,13 @@ async def initialize_server(config: SessionConfig) -> None: deps_store.add(dep) logger.info(" Pre-configured deps: %s", container_deps) + default_bundle = WorkspaceBundle( + workspace_id=None, + workflow_library=workflow_library, + artifact_store=artifact_store, + artifacts_path=str(config.artifacts_path), + ) + _state = ServerState( config=config, registry=registry, @@ -494,6 +687,8 @@ async def initialize_server(config: SessionConfig) -> None: sessions={}, start_time=time.time(), redis_mode=False, + default_bundle=default_bundle, + workspace_bundles={}, ) # Log authentication status (important for security awareness) @@ -592,21 +787,12 @@ async def execute( ) -> ExecuteResponseModel: """Execute code in an isolated session. - Pass X-Session-ID header to use a specific session. - Omit to create a new session (ID returned in response). + Pass X-Session-ID header to use a specific bound session. 
""" if _state.config is None: raise HTTPException(status_code=503, detail="Server not initialized") - # Cleanup expired sessions periodically - cleanup_expired_sessions() - - if x_session_id is None: - session = create_new_session() - else: - session = get_existing_session(x_session_id) - if session is None: - raise HTTPException(status_code=400, detail="Invalid session ID") + session = get_bound_session(x_session_id) start = time.time() timeout = body.timeout or _state.config.default_timeout @@ -628,6 +814,25 @@ async def execute( session_id=session.session_id, ) + @app.post( + "/sessions", + response_model=SessionResponseModel, + dependencies=[Depends(require_auth)], + ) + async def create_bound_session(body: CreateSessionRequestModel) -> SessionResponseModel: + """Create a new session optionally bound to a workspace-scoped bundle.""" + try: + workspace_id = ( + _validate_workspace_id(body.workspace_id) if body.workspace_id is not None else None + ) + session = create_new_session(workspace_id=workspace_id) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) + + return SessionResponseModel(session_id=session.session_id) + @app.get("/health", response_model=HealthResponseModel) async def health() -> HealthResponseModel: """Health check endpoint. 
@@ -643,25 +848,26 @@ async def health() -> HealthResponseModel: ) @app.get("/info", response_model=InfoResponseModel, dependencies=[Depends(require_auth)]) - async def info() -> InfoResponseModel: + async def info( + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> InfoResponseModel: """Get information about available tools and workflows.""" + bundle = get_request_bundle(x_session_id) tools = [] if _state.registry: for tool in _state.registry.list_tools(): tools.append({"name": tool.name, "description": tool.description}) workflows = [] - if _state.workflow_library is not None: - _state.workflow_library.refresh() - for workflow in _state.workflow_library.list(): + if bundle.workflow_library is not None: + bundle.workflow_library.refresh() + for workflow in bundle.workflow_library.list(): workflows.append({"name": workflow.name, "description": workflow.description}) - artifacts_path = str(_state.config.artifacts_path) if _state.config else "" - return InfoResponseModel( tools=tools, workflows=workflows, - artifacts_path=artifacts_path, + artifacts_path=bundle.artifacts_path, ) @app.post("/reset", response_model=ResetResponseModel, dependencies=[Depends(require_auth)]) @@ -679,7 +885,7 @@ async def reset( session_id=x_session_id, ) - # NOTE: /sessions endpoint removed - session enumeration is an information disclosure risk + # Only POST /sessions is exposed. Enumeration/listing endpoints remain unavailable. 
@app.post( "/install_deps", @@ -809,13 +1015,16 @@ async def api_search_tools(query: str, limit: int = 10) -> list[dict[str, Any]]: # ========================================================================== @app.get("/api/workflows", dependencies=[Depends(require_auth)]) - async def api_list_workflows() -> list[dict[str, Any]]: + async def api_list_workflows( + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> list[dict[str, Any]]: """Return all workflows.""" - if _state.workflow_library is None: + bundle = get_request_bundle(x_session_id) + if bundle.workflow_library is None: raise HTTPException(status_code=503, detail="Workflow library not initialized") - _state.workflow_library.refresh() - workflows = _state.workflow_library.list() + bundle.workflow_library.refresh() + workflows = bundle.workflow_library.list() return [ { "name": workflow.name, @@ -826,13 +1035,18 @@ async def api_list_workflows() -> list[dict[str, Any]]: ] @app.get("/api/workflows/search", dependencies=[Depends(require_auth)]) - async def api_search_workflows(query: str, limit: int = 5) -> list[dict[str, Any]]: + async def api_search_workflows( + query: str, + limit: int = 5, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> list[dict[str, Any]]: """Search workflows.""" - if _state.workflow_library is None: + bundle = get_request_bundle(x_session_id) + if bundle.workflow_library is None: raise HTTPException(status_code=503, detail="Workflow library not initialized") - _state.workflow_library.refresh() - workflows = _state.workflow_library.search(query, limit=limit) + bundle.workflow_library.refresh() + workflows = bundle.workflow_library.search(query, limit=limit) return [ { "name": workflow.name, @@ -843,13 +1057,17 @@ async def api_search_workflows(query: str, limit: int = 5) -> list[dict[str, Any ] @app.get("/api/workflows/{name}", dependencies=[Depends(require_auth)]) - async def api_get_workflow(name: str) -> dict[str, Any] | None: + async def 
api_get_workflow( + name: str, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> dict[str, Any] | None: """Get workflow by name with full source.""" - if _state.workflow_library is None: + bundle = get_request_bundle(x_session_id) + if bundle.workflow_library is None: raise HTTPException(status_code=503, detail="Workflow library not initialized") - _state.workflow_library.refresh() - workflow = _state.workflow_library.get(name) + bundle.workflow_library.refresh() + workflow = bundle.workflow_library.get(name) if workflow is None: return None @@ -861,9 +1079,13 @@ async def api_get_workflow(name: str) -> dict[str, Any] | None: } @app.post("/api/workflows", dependencies=[Depends(require_auth)]) - async def api_create_workflow(body: CreateWorkflowRequest) -> dict[str, Any]: + async def api_create_workflow( + body: CreateWorkflowRequest, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> dict[str, Any]: """Create a new workflow.""" - if _state.workflow_library is None: + bundle = get_request_bundle(x_session_id) + if bundle.workflow_library is None: raise HTTPException(status_code=503, detail="Workflow library not initialized") from py_code_mode.workflows import PythonWorkflow @@ -874,7 +1096,7 @@ async def api_create_workflow(body: CreateWorkflowRequest) -> dict[str, Any]: source=body.source, description=body.description, ) - _state.workflow_library.add(workflow) + bundle.workflow_library.add(workflow) return { "name": workflow.name, "description": workflow.description, @@ -885,24 +1107,28 @@ async def api_create_workflow(body: CreateWorkflowRequest) -> dict[str, Any]: raise HTTPException(status_code=400, detail=str(e)) @app.delete("/api/workflows/{name}", dependencies=[Depends(require_auth)]) - async def api_delete_workflow(name: str) -> bool: + async def api_delete_workflow( + name: str, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> bool: """Delete a workflow.""" - if _state.workflow_library is 
None: + bundle = get_request_bundle(x_session_id) + if bundle.workflow_library is None: return False - return _state.workflow_library.remove(name) + return bundle.workflow_library.remove(name) # ========================================================================== # Artifacts API Endpoints # ========================================================================== @app.get("/api/artifacts", dependencies=[Depends(require_auth)]) - async def api_list_artifacts() -> list[dict[str, Any]]: + async def api_list_artifacts( + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> list[dict[str, Any]]: """List all artifacts with metadata.""" - if _state.artifact_store is None: - return [] - - artifacts = _state.artifact_store.list() + bundle = get_request_bundle(x_session_id) + artifacts = bundle.artifact_store.list() return [ { "name": artifact.name, @@ -915,23 +1141,27 @@ async def api_list_artifacts() -> list[dict[str, Any]]: ] @app.get("/api/artifacts/{name}", dependencies=[Depends(require_auth)]) - async def api_load_artifact(name: str) -> Any: + async def api_load_artifact( + name: str, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> Any: """Load artifact data.""" - if _state.artifact_store is None: - raise HTTPException(status_code=503, detail="Artifact store not initialized") + bundle = get_request_bundle(x_session_id) try: - return _state.artifact_store.load(name) + return bundle.artifact_store.load(name) except ArtifactNotFoundError: raise HTTPException(status_code=404, detail=f"Artifact '{name}' not found") @app.post("/api/artifacts", dependencies=[Depends(require_auth)]) - async def api_save_artifact(body: SaveArtifactRequest) -> dict[str, Any]: + async def api_save_artifact( + body: SaveArtifactRequest, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> dict[str, Any]: """Save artifact.""" - if _state.artifact_store is None: - raise HTTPException(status_code=503, detail="Artifact store not 
initialized") + bundle = get_request_bundle(x_session_id) - artifact = _state.artifact_store.save( + artifact = bundle.artifact_store.save( name=body.name, data=body.data, description=body.description, @@ -946,15 +1176,17 @@ async def api_save_artifact(body: SaveArtifactRequest) -> dict[str, Any]: } @app.delete("/api/artifacts/{name}", dependencies=[Depends(require_auth)]) - async def api_delete_artifact(name: str) -> None: + async def api_delete_artifact( + name: str, + x_session_id: str | None = Header(None, alias="X-Session-ID"), + ) -> None: """Delete artifact.""" - if _state.artifact_store is None: - raise HTTPException(status_code=503, detail="Artifact store not initialized") + bundle = get_request_bundle(x_session_id) - if not _state.artifact_store.exists(name): + if not bundle.artifact_store.exists(name): raise HTTPException(status_code=404, detail=f"Artifact '{name}' not found") - _state.artifact_store.delete(name) + bundle.artifact_store.delete(name) # ========================================================================== # Deps API Endpoints diff --git a/src/py_code_mode/execution/protocol.py b/src/py_code_mode/execution/protocol.py index 570987e..3b9e8db 100644 --- a/src/py_code_mode/execution/protocol.py +++ b/src/py_code_mode/execution/protocol.py @@ -33,6 +33,7 @@ class FileStorageAccess: workflows_path: Path | None artifacts_path: Path vectors_path: Path | None = None + root_path: Path | None = None @dataclass(frozen=True) @@ -48,6 +49,7 @@ class RedisStorageAccess: workflows_prefix: str artifacts_prefix: str vectors_prefix: str | None = None + root_prefix: str | None = None StorageAccess = FileStorageAccess | RedisStorageAccess diff --git a/src/py_code_mode/execution/subprocess/namespace.py b/src/py_code_mode/execution/subprocess/namespace.py index 80cddd7..f00a3d4 100644 --- a/src/py_code_mode/execution/subprocess/namespace.py +++ b/src/py_code_mode/execution/subprocess/namespace.py @@ -102,8 +102,8 @@ def _build_file_storage_setup_code( 
repr(str(storage_access.workflows_path)) if storage_access.workflows_path else "None" ) artifacts_path_str = repr(str(storage_access.artifacts_path)) - # Base path is parent of artifacts for deps store - base_path_str = repr(str(storage_access.artifacts_path.parent)) + base_path = storage_access.root_path or storage_access.artifacts_path.parent + base_path_str = repr(str(base_path)) allow_deps_str = "True" if allow_runtime_deps else "False" vectors_path_str = ( repr(str(storage_access.vectors_path)) if storage_access.vectors_path else "None" @@ -416,7 +416,8 @@ def _build_redis_storage_setup_code( vectors_prefix_str = ( repr(storage_access.vectors_prefix) if storage_access.vectors_prefix else "None" ) - deps_prefix_str = repr(f"{storage_access.workflows_prefix.rsplit(':', 1)[0]}:deps") + deps_prefix = storage_access.root_prefix or storage_access.workflows_prefix.rsplit(":", 1)[0] + deps_prefix_str = repr(deps_prefix) allow_deps_str = "True" if allow_runtime_deps else "False" return f'''# Auto-generated namespace setup for SubprocessExecutor (Redis) diff --git a/src/py_code_mode/storage/backends.py b/src/py_code_mode/storage/backends.py index 4377b4d..48be83e 100644 --- a/src/py_code_mode/storage/backends.py +++ b/src/py_code_mode/storage/backends.py @@ -9,6 +9,7 @@ from __future__ import annotations import logging +import re from pathlib import Path from typing import TYPE_CHECKING, ClassVar, Protocol, runtime_checkable from urllib.parse import quote @@ -45,10 +46,22 @@ logger = logging.getLogger(__name__) +_WORKSPACE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_-]{1,128}$") + if TYPE_CHECKING: from redis import Redis +def _validate_workspace_id(workspace_id: str) -> str: + """Validate a workspace identifier used in paths and Redis prefixes.""" + if not _WORKSPACE_ID_PATTERN.fullmatch(workspace_id): + raise ValueError( + "workspace_id must be 1-128 characters using only ASCII letters, digits, " + "underscores, or hyphens" + ) + return workspace_id + + @runtime_checkable 
class StorageBackend(Protocol): """Protocol for unified storage backend. @@ -79,6 +92,11 @@ def get_artifact_store(self) -> ArtifactStoreProtocol: """ ... + @property + def workspace_id(self) -> str | None: + """Return the configured workspace scope, if any.""" + ... + class FileStorage: """File-based storage using directories for workflows and artifacts. @@ -88,14 +106,18 @@ class FileStorage: _UNINITIALIZED: ClassVar[object] = object() - def __init__(self, base_path: Path | str) -> None: + def __init__(self, base_path: Path | str, workspace_id: str | None = None) -> None: """Initialize file storage. Args: base_path: Base directory for storage. Will create workflows/, artifacts/ subdirs. + workspace_id: Optional workspace scope for workflows, artifacts, and vectors. """ self._base_path = Path(base_path) if isinstance(base_path, str) else base_path self._base_path.mkdir(parents=True, exist_ok=True) + self._workspace_id = ( + _validate_workspace_id(workspace_id) if workspace_id is not None else None + ) # Lazy-initialized stores (workflows and artifacts only) self._workflow_library: WorkflowLibrary | None = None @@ -107,21 +129,32 @@ def root(self) -> Path: """Get the root storage path.""" return self._base_path + @property + def workspace_id(self) -> str | None: + """Get the configured workspace scope.""" + return self._workspace_id + + def _get_storage_root(self) -> Path: + """Get the scoped root for workflows, artifacts, and vectors.""" + if self._workspace_id is None: + return self._base_path + return self._base_path / "workspaces" / self._workspace_id + def _get_workflows_path(self) -> Path: """Get the workflows directory path.""" - workflows_path = self._base_path / "workflows" + workflows_path = self._get_storage_root() / "workflows" workflows_path.mkdir(parents=True, exist_ok=True) return workflows_path def _get_artifacts_path(self) -> Path: """Get the artifacts directory path.""" - artifacts_path = self._base_path / "artifacts" + artifacts_path = 
self._get_storage_root() / "artifacts" artifacts_path.mkdir(parents=True, exist_ok=True) return artifacts_path def _get_vectors_path(self) -> Path: """Get the vectors directory path.""" - vectors_path = self._base_path / "vectors" + vectors_path = self._get_storage_root() / "vectors" vectors_path.mkdir(parents=True, exist_ok=True) return vectors_path @@ -153,13 +186,14 @@ def get_vector_store(self) -> VectorStore | None: def get_serializable_access(self) -> FileStorageAccess: """Return FileStorageAccess for cross-process communication.""" - base_path = self._base_path - vectors_path = base_path / "vectors" + storage_root = self._get_storage_root() + vectors_path = storage_root / "vectors" return FileStorageAccess( - workflows_path=base_path / "workflows", - artifacts_path=base_path / "artifacts", + workflows_path=storage_root / "workflows", + artifacts_path=storage_root / "artifacts", vectors_path=vectors_path if vectors_path.exists() else None, + root_path=self._base_path, ) def get_workflow_library(self) -> WorkflowLibrary: @@ -206,10 +240,13 @@ def to_bootstrap_config(self) -> dict[str, str]: This config can be passed to bootstrap_namespaces() to reconstruct the storage in a subprocess. """ - return { + config = { "type": "file", "base_path": str(self._base_path), } + if self._workspace_id is not None: + config["workspace_id"] = self._workspace_id + return config class RedisStorage: @@ -225,6 +262,7 @@ def __init__( url: str | None = None, redis: Redis | None = None, prefix: str = "py_code_mode", + workspace_id: str | None = None, ) -> None: """Initialize Redis storage. @@ -234,6 +272,7 @@ def __init__( redis: Redis client instance. Use for advanced configurations (custom connection pools, etc.). Mutually exclusive with url. prefix: Key prefix for all storage. Default: "py_code_mode" + workspace_id: Optional workspace scope for workflows, artifacts, and vectors. Raises: ValueError: If neither url nor redis is provided, or if both are. 
@@ -260,6 +299,12 @@ def __init__( raise ValueError("Redis client is required") self._prefix = prefix + self._workspace_id = ( + _validate_workspace_id(workspace_id) if workspace_id is not None else None + ) + self._storage_prefix = ( + f"{prefix}:ws:{self._workspace_id}" if self._workspace_id is not None else prefix + ) # Lazy-initialized stores (workflows and artifacts only) self._workflow_library: WorkflowLibrary | None = None @@ -271,6 +316,11 @@ def prefix(self) -> str: """Get the configured prefix.""" return self._prefix + @property + def workspace_id(self) -> str | None: + """Get the configured workspace scope.""" + return self._workspace_id + @property def client(self) -> Redis: """Get the Redis client.""" @@ -322,7 +372,7 @@ def get_vector_store(self) -> VectorStore | None: self._vector_store = RedisVectorStore( redis=self._redis, embedder=embedder, - prefix=f"{self._prefix}:vectors", + prefix=f"{self._storage_prefix}:vectors", ) except ImportError: self._vector_store = None @@ -343,7 +393,7 @@ def get_serializable_access(self) -> RedisStorageAccess: else: redis_url = self._reconstruct_redis_url() - prefix = self._prefix + prefix = self._storage_prefix # vectors_prefix is set when RedisVectorStore dependencies are available # (redis-py with RediSearch). We check module availability, not actual # vector store creation, to avoid side effects during serialization. 
@@ -355,12 +405,13 @@ def get_serializable_access(self) -> RedisStorageAccess: workflows_prefix=f"{prefix}:workflows", artifacts_prefix=f"{prefix}:artifacts", vectors_prefix=vectors_prefix, + root_prefix=self._prefix, ) def get_workflow_library(self) -> WorkflowLibrary: """Return WorkflowLibrary for in-process execution.""" if self._workflow_library is None: - raw_store = RedisWorkflowStore(self._redis, prefix=f"{self._prefix}:workflows") + raw_store = RedisWorkflowStore(self._redis, prefix=f"{self._storage_prefix}:workflows") vector_store = self.get_vector_store() try: self._workflow_library = create_workflow_library( @@ -385,13 +436,13 @@ def get_artifact_store(self) -> ArtifactStoreProtocol: """Return artifact store for in-process execution.""" if self._artifact_store is None: self._artifact_store = RedisArtifactStore( - self._redis, prefix=f"{self._prefix}:artifacts" + self._redis, prefix=f"{self._storage_prefix}:artifacts" ) return self._artifact_store def get_workflow_store(self) -> WorkflowStore: """Return the underlying WorkflowStore for direct access.""" - return RedisWorkflowStore(self._redis, prefix=f"{self._prefix}:workflows") + return RedisWorkflowStore(self._redis, prefix=f"{self._storage_prefix}:workflows") def to_bootstrap_config(self) -> dict[str, str]: """Serialize storage configuration for subprocess bootstrap. @@ -401,8 +452,11 @@ def to_bootstrap_config(self) -> dict[str, str]: This config can be passed to bootstrap_namespaces() to reconstruct the storage in a subprocess. 
""" - return { + config = { "type": "redis", "url": self._reconstruct_redis_url(), "prefix": self._prefix, } + if self._workspace_id is not None: + config["workspace_id"] = self._workspace_id + return config diff --git a/tests/container/test_client.py b/tests/container/test_client.py index 7dd7826..d7f105b 100644 --- a/tests/container/test_client.py +++ b/tests/container/test_client.py @@ -1,5 +1,6 @@ """Tests for session server HTTP client.""" +import asyncio from unittest.mock import AsyncMock, MagicMock import httpx @@ -17,6 +18,11 @@ def make_mock_response(json_data: dict, status_code: int = 200) -> MagicMock: return mock_response +def make_session_response(session_id: str = "server-session-1") -> MagicMock: + """Create a mock session creation response.""" + return make_mock_response({"session_id": session_id}) + + class TestSessionClient: """Tests for SessionClient.""" @@ -45,7 +51,8 @@ async def test_execute_simple_code(self) -> None: """Execute returns result from server.""" client = SessionClient() - mock_response = make_mock_response( + session_response = make_session_response() + execute_response = make_mock_response( { "value": 42, "stdout": "", @@ -57,16 +64,21 @@ async def test_execute_simple_code(self) -> None: # Mock the internal client's post method mock_http_client = AsyncMock() - mock_http_client.post = AsyncMock(return_value=mock_response) + mock_http_client.post = AsyncMock(side_effect=[session_response, execute_response]) client._client = mock_http_client result = await client.execute("21 * 2") - mock_http_client.post.assert_called_once() - call_args = mock_http_client.post.call_args - assert call_args[0][0] == "http://localhost:8080/execute" - assert call_args[1]["json"]["code"] == "21 * 2" - assert call_args[1]["headers"] == {} + assert mock_http_client.post.call_count == 2 + session_call = mock_http_client.post.call_args_list[0] + assert session_call[0][0] == "http://localhost:8080/sessions" + assert session_call[1]["json"] == {} + assert 
session_call[1]["headers"] == {} + + execute_call = mock_http_client.post.call_args_list[1] + assert execute_call[0][0] == "http://localhost:8080/execute" + assert execute_call[1]["json"]["code"] == "21 * 2" + assert execute_call[1]["headers"]["X-Session-ID"] == "server-session-1" assert result.value == 42 assert result.error is None @@ -79,7 +91,8 @@ async def test_execute_with_timeout(self) -> None: """Execute passes timeout to server.""" client = SessionClient() - mock_response = make_mock_response( + session_response = make_session_response() + execute_response = make_mock_response( { "value": None, "stdout": "", @@ -90,12 +103,12 @@ async def test_execute_with_timeout(self) -> None: ) mock_http_client = AsyncMock() - mock_http_client.post = AsyncMock(return_value=mock_response) + mock_http_client.post = AsyncMock(side_effect=[session_response, execute_response]) client._client = mock_http_client await client.execute("import time; time.sleep(1)", timeout=60.0) - call_args = mock_http_client.post.call_args + call_args = mock_http_client.post.call_args_list[1] assert call_args[1]["json"]["timeout"] == 60.0 @pytest.mark.asyncio @@ -103,7 +116,8 @@ async def test_execute_with_error(self) -> None: """Execute returns error from server.""" client = SessionClient() - mock_response = make_mock_response( + session_response = make_session_response() + execute_response = make_mock_response( { "value": None, "stdout": "", @@ -114,7 +128,7 @@ async def test_execute_with_error(self) -> None: ) mock_http_client = AsyncMock() - mock_http_client.post = AsyncMock(return_value=mock_response) + mock_http_client.post = AsyncMock(side_effect=[session_response, execute_response]) client._client = mock_http_client result = await client.execute("1/0") @@ -128,6 +142,7 @@ async def test_execute_reuses_server_assigned_session_id(self) -> None: """Second execute sends the server-issued session ID.""" client = SessionClient() + session_response = make_session_response() first_response = 
make_mock_response( { "value": 42, @@ -148,15 +163,200 @@ async def test_execute_reuses_server_assigned_session_id(self) -> None: ) mock_http_client = AsyncMock() - mock_http_client.post = AsyncMock(side_effect=[first_response, second_response]) + mock_http_client.post = AsyncMock( + side_effect=[session_response, first_response, second_response] + ) client._client = mock_http_client await client.execute("x = 42") await client.execute("x * 2") - second_call = mock_http_client.post.call_args_list[1] + second_call = mock_http_client.post.call_args_list[2] assert second_call[1]["headers"]["X-Session-ID"] == "server-session-1" + @pytest.mark.asyncio + async def test_concurrent_first_use_binds_only_one_remote_session(self) -> None: + """Concurrent first use should not create multiple remote sessions.""" + client = SessionClient() + + session_response = make_session_response() + execute_responses = [ + make_mock_response( + { + "value": 1, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "server-session-1", + } + ), + make_mock_response( + { + "value": 2, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "server-session-1", + } + ), + ] + + async def post_side_effect(url: str, **kwargs): + if url.endswith("/sessions"): + await asyncio.sleep(0) + return session_response + if url.endswith("/execute"): + await asyncio.sleep(0) + return execute_responses.pop(0) + raise AssertionError(f"Unexpected URL {url}") + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(side_effect=post_side_effect) + client._client = mock_http_client + + results = await asyncio.gather(client.execute("1"), client.execute("2")) + + session_calls = [ + call + for call in mock_http_client.post.call_args_list + if call[0][0].endswith("/sessions") + ] + assert len(session_calls) == 1 + + execute_calls = [ + call for call in mock_http_client.post.call_args_list if call[0][0].endswith("/execute") + ] + assert len(execute_calls) == 2 + 
assert all( + call[1]["headers"]["X-Session-ID"] == "server-session-1" for call in execute_calls + ) + assert [result.value for result in results] == [1, 2] + + @pytest.mark.asyncio + async def test_concurrent_stale_session_rebinds_only_once(self) -> None: + """Concurrent stale-session recovery should create only one replacement session.""" + client = SessionClient() + client.session_id = "stale-session" + client._workspace_id = "workspace-a" + + invalid_response = make_mock_response({"detail": "Invalid session ID"}, status_code=400) + rebound_session = make_session_response("rebound-session") + rebound_values = [10, 20] + + async def post_side_effect(url: str, **kwargs): + if url.endswith("/sessions"): + await asyncio.sleep(0) + return rebound_session + if url.endswith("/execute"): + await asyncio.sleep(0) + session_id = kwargs["headers"]["X-Session-ID"] + if session_id == "stale-session": + return invalid_response + if session_id == "rebound-session": + value = rebound_values.pop(0) + return make_mock_response( + { + "value": value, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "rebound-session", + } + ) + raise AssertionError(f"Unexpected call {url} {kwargs}") + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(side_effect=post_side_effect) + client._client = mock_http_client + + results = await asyncio.gather(client.execute("10"), client.execute("20")) + + session_calls = [ + call + for call in mock_http_client.post.call_args_list + if call[0][0].endswith("/sessions") + ] + assert len(session_calls) == 1 + + execute_calls = [ + call for call in mock_http_client.post.call_args_list if call[0][0].endswith("/execute") + ] + stale_calls = [ + call for call in execute_calls if call[1]["headers"]["X-Session-ID"] == "stale-session" + ] + rebound_calls = [ + call + for call in execute_calls + if call[1]["headers"]["X-Session-ID"] == "rebound-session" + ] + assert len(stale_calls) == 1 + assert len(rebound_calls) == 2 + 
assert [result.value for result in results] == [10, 20] + assert client.session_id == "rebound-session" + + @pytest.mark.asyncio + async def test_init_session_waits_for_stale_rebind_before_switching_workspace(self) -> None: + """Explicit rebinding should not redirect an in-flight retry into a new workspace.""" + client = SessionClient() + client.session_id = "stale-session-a" + client._workspace_id = "workspace-a" + + stale_seen = asyncio.Event() + rebound_execute_seen = asyncio.Event() + session_payloads: list[dict[str, str]] = [] + + async def post_side_effect(url: str, **kwargs): + if url.endswith("/execute"): + session_id = kwargs["headers"]["X-Session-ID"] + if session_id == "stale-session-a": + stale_seen.set() + return make_mock_response( + {"detail": "Invalid session ID"}, + status_code=400, + ) + if session_id == "rebound-session-a": + rebound_execute_seen.set() + return make_mock_response( + { + "value": 1, + "stdout": "", + "error": None, + "execution_time_ms": 5.0, + "session_id": "rebound-session-a", + } + ) + if url.endswith("/sessions"): + session_payloads.append(kwargs["json"]) + workspace_id = kwargs["json"].get("workspace_id") + if workspace_id == "workspace-a": + await asyncio.sleep(0) + return make_mock_response({"session_id": "rebound-session-a"}) + if workspace_id == "workspace-b": + await asyncio.sleep(0) + return make_mock_response({"session_id": "session-b"}) + raise AssertionError(f"Unexpected call {url} {kwargs}") + + mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(side_effect=post_side_effect) + client._client = mock_http_client + + execute_task = asyncio.create_task(client.execute("1")) + await stale_seen.wait() + init_task = asyncio.create_task(client.init_session("workspace-b")) + + execute_result = await execute_task + new_session_id = await init_task + + assert execute_result.value == 1 + assert execute_result.session_id == "rebound-session-a" + assert rebound_execute_seen.is_set() + assert new_session_id == 
"session-b" + assert client.session_id == "session-b" + assert session_payloads == [ + {"workspace_id": "workspace-a"}, + {"workspace_id": "workspace-b"}, + ] + class TestSessionClientHealth: """Tests for health check method.""" @@ -191,6 +391,7 @@ async def test_info_returns_tools_and_workflows(self) -> None: """Info returns available tools and workflows.""" client = SessionClient() + session_response = make_session_response() mock_response = make_mock_response( { "tools": [{"name": "cli.nmap", "description": "Network scanner"}], @@ -200,6 +401,7 @@ async def test_info_returns_tools_and_workflows(self) -> None: ) mock_http_client = AsyncMock() + mock_http_client.post = AsyncMock(return_value=session_response) mock_http_client.get = AsyncMock(return_value=mock_response) client._client = mock_http_client @@ -210,7 +412,7 @@ async def test_info_returns_tools_and_workflows(self) -> None: assert len(info.workflows) == 1 assert info.workflows[0]["name"] == "scan" call_args = mock_http_client.get.call_args - assert call_args[1]["headers"] == {} + assert call_args[1]["headers"]["X-Session-ID"] == "server-session-1" class TestSessionClientReset: diff --git a/tests/container/test_container_api.py b/tests/container/test_container_api.py index af36d81..646cfcd 100644 --- a/tests/container/test_container_api.py +++ b/tests/container/test_container_api.py @@ -16,6 +16,22 @@ import pytest + +def auth_headers(token: str, session_id: str | None = None) -> dict[str, str]: + """Build auth headers for container API requests.""" + headers = {"Authorization": f"Bearer {token}"} + if session_id is not None: + headers["X-Session-ID"] = session_id + return headers + + +def create_bound_session(client, token: str) -> str: + """Create a bound session for session-aware API tests.""" + response = client.post("/sessions", json={}, headers=auth_headers(token)) + assert response.status_code == 200 + return response.json()["session_id"] + + # 
============================================================================= # SECTION 1: TOOLS API # ============================================================================= @@ -110,9 +126,10 @@ def auth_client(self, tmp_path): def test_list_workflows_returns_empty_when_no_workflows(self, auth_client) -> None: """GET /api/workflows returns empty list when no workflows registered.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.get( "/api/workflows", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -127,10 +144,11 @@ def test_list_workflows_requires_auth(self, auth_client) -> None: def test_search_workflows_returns_empty_when_no_workflows(self, auth_client) -> None: """GET /api/workflows/search returns empty list when no workflows registered.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.get( "/api/workflows/search", params={"query": "fetch"}, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -164,7 +182,7 @@ def test_workflows_endpoints_return_503_if_library_not_initialized(self, tmp_pat app = create_app(config) with TestClient(app) as client: - headers = {"Authorization": "Bearer test-token"} + headers = auth_headers("test-token", create_bound_session(client, "test-token")) resp = client.get("/api/workflows", headers=headers) assert resp.status_code == 503 @@ -175,9 +193,10 @@ def test_workflows_endpoints_return_503_if_library_not_initialized(self, tmp_pat def test_get_workflow_returns_none_when_not_found(self, auth_client) -> None: """GET /api/workflows/{name} returns null when workflow not found.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.get( "/api/workflows/nonexistent", - 
headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -192,6 +211,7 @@ def test_get_workflow_requires_auth(self, auth_client) -> None: def test_create_workflow_success(self, auth_client) -> None: """POST /api/workflows creates a new workflow.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.post( "/api/workflows", json={ @@ -199,7 +219,7 @@ def test_create_workflow_success(self, auth_client) -> None: "source": "async def run(x: int) -> int:\n return x * 2", "description": "Doubles a number", }, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -212,6 +232,7 @@ def test_list_workflows_refreshes_after_external_create(self, auth_client, tmp_p from py_code_mode.workflows import FileWorkflowStore, PythonWorkflow client, token = auth_client + session_id = create_bound_session(client, token) store = FileWorkflowStore(tmp_path / "workflows") store.save( PythonWorkflow.from_source( @@ -223,7 +244,7 @@ def test_list_workflows_refreshes_after_external_create(self, auth_client, tmp_p response = client.get( "/api/workflows", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 @@ -234,6 +255,7 @@ def test_get_workflow_refreshes_after_external_edit(self, auth_client, tmp_path) from py_code_mode.workflows import FileWorkflowStore, PythonWorkflow client, token = auth_client + session_id = create_bound_session(client, token) store = FileWorkflowStore(tmp_path / "workflows") store.save( PythonWorkflow.from_source( @@ -253,7 +275,7 @@ def test_get_workflow_refreshes_after_external_edit(self, auth_client, tmp_path) response = client.get( "/api/workflows/editable", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert 
response.status_code == 200 @@ -264,6 +286,7 @@ def test_info_refreshes_after_external_workflow_create(self, auth_client, tmp_pa from py_code_mode.workflows import FileWorkflowStore, PythonWorkflow client, token = auth_client + session_id = create_bound_session(client, token) store = FileWorkflowStore(tmp_path / "workflows") store.save( PythonWorkflow.from_source( @@ -275,7 +298,7 @@ def test_info_refreshes_after_external_workflow_create(self, auth_client, tmp_pa response = client.get( "/info", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 @@ -299,6 +322,7 @@ def test_create_workflow_requires_auth(self, auth_client) -> None: def test_create_workflow_invalid_source_returns_400(self, auth_client) -> None: """POST /api/workflows returns 400 for invalid source code.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.post( "/api/workflows", json={ @@ -306,13 +330,14 @@ def test_create_workflow_invalid_source_returns_400(self, auth_client) -> None: "source": "not valid python +++", "description": "Invalid", }, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 400 def test_create_workflow_no_run_returns_400(self, auth_client) -> None: """POST /api/workflows returns 400 when source has no run() function.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.post( "/api/workflows", json={ @@ -320,7 +345,7 @@ def test_create_workflow_no_run_returns_400(self, auth_client) -> None: "source": "def other_func(): pass", "description": "No run", }, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 400 @@ -333,9 +358,10 @@ def test_delete_workflow_requires_auth(self, auth_client) -> None: def test_delete_workflow_returns_false_when_not_found(self, auth_client) -> 
None: """DELETE /api/workflows/{name} returns false when workflow not found.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.delete( "/api/workflows/nonexistent", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 assert response.json() is False @@ -343,7 +369,7 @@ def test_delete_workflow_returns_false_when_not_found(self, auth_client) -> None def test_workflow_lifecycle_create_get_delete(self, auth_client) -> None: """Full workflow lifecycle: create, get, delete.""" client, token = auth_client - headers = {"Authorization": f"Bearer {token}"} + headers = auth_headers(token, create_bound_session(client, token)) # Create workflow_source = ( @@ -409,9 +435,10 @@ def auth_client(self, tmp_path): def test_list_artifacts_returns_empty_when_no_artifacts(self, auth_client) -> None: """GET /api/artifacts returns empty list when no artifacts saved.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.get( "/api/artifacts", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -427,9 +454,10 @@ def test_list_artifacts_requires_auth(self, auth_client) -> None: def test_load_artifact_returns_404_when_not_found(self, auth_client) -> None: """GET /api/artifacts/{name} returns 404 when artifact not found.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.get( "/api/artifacts/nonexistent", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 404 @@ -442,7 +470,7 @@ def test_load_artifact_requires_auth(self, auth_client) -> None: def test_list_artifacts_omits_externally_deleted_artifact(self, auth_client, tmp_path) -> None: """GET /api/artifacts prunes stale metadata after external file 
deletion.""" client, token = auth_client - headers = {"Authorization": f"Bearer {token}"} + headers = auth_headers(token, create_bound_session(client, token)) response = client.post( "/api/artifacts", @@ -466,7 +494,7 @@ def test_load_artifact_returns_404_after_external_file_delete( ) -> None: """GET /api/artifacts/{name} returns 404 when a tracked file is deleted externally.""" client, token = auth_client - headers = {"Authorization": f"Bearer {token}"} + headers = auth_headers(token, create_bound_session(client, token)) response = client.post( "/api/artifacts", @@ -487,6 +515,7 @@ def test_load_artifact_returns_404_after_external_file_delete( def test_save_artifact_success(self, auth_client) -> None: """POST /api/artifacts saves an artifact.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.post( "/api/artifacts", json={ @@ -494,7 +523,7 @@ def test_save_artifact_success(self, auth_client) -> None: "data": {"key": "value", "number": 42}, "description": "Test artifact", }, - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 200 data = response.json() @@ -519,16 +548,17 @@ def test_delete_artifact_requires_auth(self, auth_client) -> None: def test_delete_artifact_returns_404_when_not_found(self, auth_client) -> None: """DELETE /api/artifacts/{name} returns 404 when artifact not found.""" client, token = auth_client + session_id = create_bound_session(client, token) response = client.delete( "/api/artifacts/nonexistent", - headers={"Authorization": f"Bearer {token}"}, + headers=auth_headers(token, session_id), ) assert response.status_code == 404 def test_artifact_lifecycle_save_load_delete(self, auth_client) -> None: """Full artifact lifecycle: save, load, delete.""" client, token = auth_client - headers = {"Authorization": f"Bearer {token}"} + headers = auth_headers(token, create_bound_session(client, token)) # Save response = client.post( diff 
--git a/tests/container/test_container_auth.py b/tests/container/test_container_auth.py index 09269a8..5e4c3ea 100644 --- a/tests/container/test_container_auth.py +++ b/tests/container/test_container_auth.py @@ -23,6 +23,14 @@ from py_code_mode.execution.container.config import ContainerConfig, SessionConfig + +def create_bound_session(client, headers: dict[str, str] | None = None) -> str: + """Create a bound session for session-aware endpoint tests.""" + response = client.post("/sessions", json={}, headers=headers or {}) + assert response.status_code == 200 + return response.json()["session_id"] + + # ============================================================================= # SECTION 1: AUTH REJECTION (Critical Security) # ============================================================================= @@ -250,14 +258,8 @@ def test_protected_endpoint_with_valid_token_succeeds( """Protected endpoints succeed with valid token.""" headers = {"Authorization": f"Bearer {auth_token}"} - if endpoint == "/reset": - create_response = auth_enabled_client.post( - "/execute", - json={"code": "x = 42"}, - headers=headers, - ) - assert create_response.status_code == 200 - session_id = create_response.json()["session_id"] + if endpoint in {"/execute", "/reset", "/info"}: + session_id = create_bound_session(auth_enabled_client, headers=headers) headers["X-Session-ID"] = session_id if method == "post": @@ -285,6 +287,8 @@ def test_token_with_urlsafe_special_chars_succeeds(self, tmp_path) -> None: app = create_app(config) with TestClient(app) as client: headers = {"Authorization": f"Bearer {special_token}"} + session_id = create_bound_session(client, headers=headers) + headers["X-Session-ID"] = session_id response = client.post("/execute", json={"code": "1 + 1"}, headers=headers) assert response.status_code == 200 @@ -320,13 +324,20 @@ def auth_disabled_client(self, tmp_path, monkeypatch): def test_requests_without_token_succeed_when_auth_disabled(self, auth_disabled_client) -> 
None: """Requests without token succeed when auth is explicitly disabled.""" - response = auth_disabled_client.post("/execute", json={"code": "1 + 1"}) + session_id = create_bound_session(auth_disabled_client) + response = auth_disabled_client.post( + "/execute", + json={"code": "1 + 1"}, + headers={"X-Session-ID": session_id}, + ) assert response.status_code == 200 def test_token_sent_to_disabled_server_is_ignored(self, auth_disabled_client) -> None: """Token sent to auth-disabled server is ignored (not validated).""" # Even an invalid token should be accepted (ignored) headers = {"Authorization": "Bearer some-random-token"} + session_id = create_bound_session(auth_disabled_client, headers=headers) + headers["X-Session-ID"] = session_id response = auth_disabled_client.post("/execute", json={"code": "1 + 1"}, headers=headers) assert response.status_code == 200 @@ -515,7 +526,7 @@ def test_health_endpoint_with_valid_token_also_works(self, auth_enabled_client) class TestSessionsEndpointRemoved: - """Tests verifying sessions endpoint is removed (information leakage).""" + """Tests verifying session enumeration remains unavailable.""" @pytest.fixture def client(self, tmp_path): @@ -534,16 +545,16 @@ def client(self, tmp_path): with TestClient(app) as client: yield client - def test_sessions_endpoint_returns_404(self, client) -> None: - """GET /sessions returns 404 (endpoint removed).""" + def test_sessions_endpoint_get_returns_405(self, client) -> None: + """GET /sessions is not allowed; only POST session creation is supported.""" response = client.get("/sessions") - assert response.status_code == 404 + assert response.status_code == 405 - def test_sessions_endpoint_with_valid_auth_still_returns_404(self, client) -> None: - """GET /sessions with valid auth still returns 404 (endpoint removed).""" + def test_sessions_endpoint_get_with_valid_auth_still_returns_405(self, client) -> None: + """GET /sessions stays unavailable even with valid auth.""" headers = 
{"Authorization": "Bearer secret-token"} response = client.get("/sessions", headers=headers) - assert response.status_code == 404 + assert response.status_code == 405 # ============================================================================= @@ -637,29 +648,37 @@ async def test_session_client_sends_authorization_header(self) -> None: auth_token="client-auth-token", ) - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { + session_response = MagicMock() + session_response.status_code = 200 + session_response.json.return_value = {"session_id": "test"} + session_response.raise_for_status = MagicMock() + + execute_response = MagicMock() + execute_response.status_code = 200 + execute_response.json.return_value = { "value": 42, "stdout": "", "error": None, "execution_time_ms": 1.0, "session_id": "test", } - mock_response.raise_for_status = MagicMock() + execute_response.raise_for_status = MagicMock() mock_http_client = AsyncMock() - mock_http_client.post = AsyncMock(return_value=mock_response) + mock_http_client.post = AsyncMock(side_effect=[session_response, execute_response]) client._client = mock_http_client await client.execute("1 + 1") - # Verify Authorization header was sent - call_args = mock_http_client.post.call_args - headers = call_args[1].get("headers", {}) - assert "Authorization" in headers - assert headers["Authorization"] == "Bearer client-auth-token" - assert "X-Session-ID" not in headers + session_call = mock_http_client.post.call_args_list[0] + session_headers = session_call[1].get("headers", {}) + assert session_headers["Authorization"] == "Bearer client-auth-token" + assert "X-Session-ID" not in session_headers + + execute_call = mock_http_client.post.call_args_list[1] + execute_headers = execute_call[1].get("headers", {}) + assert execute_headers["Authorization"] == "Bearer client-auth-token" + assert execute_headers["X-Session-ID"] == "test" await client.close() diff --git 
a/tests/container/test_remote_workspace.py b/tests/container/test_remote_workspace.py new file mode 100644 index 0000000..c8d25d5 --- /dev/null +++ b/tests/container/test_remote_workspace.py @@ -0,0 +1,880 @@ +"""Real-behavior tests for remote workspace scoping in the session server.""" + +import asyncio +import os + +import pytest +import redis +import uvicorn + +import docker +from py_code_mode.execution.container import ContainerConfig, ContainerExecutor +from py_code_mode.execution.container.config import SessionConfig +from py_code_mode.session import Session +from py_code_mode.storage import FileStorage, RedisStorage +from tests.docker_diagnostics import did_test_fail, emit_testcontainer_logs + +REMOTE_REDIS_STORAGE_PREFIX = "remote_workspace_journey" + + +def auth_headers(token: str, session_id: str | None = None) -> dict[str, str]: + """Build auth headers for container API requests.""" + headers = {"Authorization": f"Bearer {token}"} + if session_id is not None: + headers["X-Session-ID"] = session_id + return headers + + +def create_scoped_session(client, token: str, workspace_id: str | None = None) -> str: + """Create a server session optionally bound to a workspace.""" + payload = {} if workspace_id is None else {"workspace_id": workspace_id} + response = client.post("/sessions", json=payload, headers=auth_headers(token)) + assert response.status_code == 200 + data = response.json() + assert data["session_id"] + return data["session_id"] + + +def _docker_daemon_is_available() -> bool: + """Check whether Docker is available for real Redis container tests.""" + try: + client = docker.from_env() + client.ping() + except Exception: + return False + return True + + +class TestRemoteWorkspaceSessions: + """Tests for workspace-aware remote session behavior.""" + + @pytest.fixture + def auth_client(self, tmp_path): + """Create authenticated client with real workflow and artifact storage.""" + try: + from fastapi.testclient import TestClient + except ImportError: 
+ pytest.skip("FastAPI not installed") + + from py_code_mode.execution.container.server import create_app + + config = SessionConfig( + artifacts_path=tmp_path / "artifacts", + workflows_path=tmp_path / "workflows", + ) + config.auth_token = "test-token" + + app = create_app(config) + with TestClient(app) as client: + yield client, "test-token" + + def test_create_session_accepts_workspace_id(self, auth_client) -> None: + """POST /sessions should allow explicit workspace binding.""" + client, token = auth_client + + response = client.post( + "/sessions", + json={"workspace_id": "client_a"}, + headers=auth_headers(token), + ) + + assert response.status_code == 200 + assert response.json()["session_id"] + + def test_create_session_rejects_invalid_workspace_id(self, auth_client) -> None: + """POST /sessions should reject invalid workspace IDs.""" + client, token = auth_client + + response = client.post( + "/sessions", + json={"workspace_id": "../escape"}, + headers=auth_headers(token), + ) + + assert response.status_code == 400 + + def test_same_workspace_sessions_share_workflows(self, auth_client) -> None: + """Sessions in the same workspace should see the same workflows.""" + client, token = auth_client + session_a = create_scoped_session(client, token, "client_a") + session_b = create_scoped_session(client, token, "client_a") + + response = client.post( + "/api/workflows", + json={ + "name": "shared_workflow", + "source": 'async def run() -> str:\n return "ok"\n', + "description": "Shared workflow", + }, + headers=auth_headers(token, session_a), + ) + assert response.status_code == 200 + + response = client.get( + "/api/workflows", + headers=auth_headers(token, session_b), + ) + assert response.status_code == 200 + assert any(workflow["name"] == "shared_workflow" for workflow in response.json()) + + def test_different_workspace_sessions_isolate_workflows(self, auth_client) -> None: + """Sessions in different workspaces should not share workflows.""" + client, token 
= auth_client + session_a = create_scoped_session(client, token, "client_a") + session_b = create_scoped_session(client, token, "client_b") + + response = client.post( + "/api/workflows", + json={ + "name": "isolated_workflow", + "source": 'async def run() -> str:\n return "ok"\n', + "description": "Workspace A only", + }, + headers=auth_headers(token, session_a), + ) + assert response.status_code == 200 + + response = client.get( + "/api/workflows", + headers=auth_headers(token, session_b), + ) + assert response.status_code == 200 + assert all(workflow["name"] != "isolated_workflow" for workflow in response.json()) + + def test_different_workspace_sessions_isolate_workflow_search(self, auth_client) -> None: + """Workflow search should respect workspace boundaries.""" + client, token = auth_client + session_a = create_scoped_session(client, token, "client_a") + session_b = create_scoped_session(client, token, "client_b") + + response = client.post( + "/api/workflows", + json={ + "name": "summarize_notes", + "source": 'async def run() -> str:\n return "ok"\n', + "description": "Summarize notes", + }, + headers=auth_headers(token, session_a), + ) + assert response.status_code == 200 + + response = client.get( + "/api/workflows/search", + params={"query": "summarize"}, + headers=auth_headers(token, session_b), + ) + assert response.status_code == 200 + assert response.json() == [] + + def test_different_workspace_sessions_isolate_artifacts(self, auth_client) -> None: + """Sessions in different workspaces should not share artifacts.""" + client, token = auth_client + session_a = create_scoped_session(client, token, "client_a") + session_b = create_scoped_session(client, token, "client_b") + + response = client.post( + "/api/artifacts", + json={ + "name": "notes.json", + "data": {"owner": "workspace-a"}, + "description": "Workspace artifact", + }, + headers=auth_headers(token, session_a), + ) + assert response.status_code == 200 + + response = client.get( + 
"/api/artifacts", + headers=auth_headers(token, session_b), + ) + assert response.status_code == 200 + assert response.json() == [] + + def test_info_uses_bound_session_workspace(self, auth_client) -> None: + """GET /info should return workflows from the bound session's workspace.""" + client, token = auth_client + session_a = create_scoped_session(client, token, "client_a") + session_b = create_scoped_session(client, token, "client_b") + + response = client.post( + "/api/workflows", + json={ + "name": "workspace_a_only", + "source": 'async def run() -> str:\n return "ok"\n', + "description": "Visible only to workspace A", + }, + headers=auth_headers(token, session_a), + ) + assert response.status_code == 200 + + response = client.get("/info", headers=auth_headers(token, session_b)) + assert response.status_code == 200 + assert all( + workflow["name"] != "workspace_a_only" for workflow in response.json()["workflows"] + ) + + def test_unknown_session_is_rejected_by_session_aware_api(self, auth_client) -> None: + """Session-aware workflow APIs should reject unknown session IDs.""" + client, token = auth_client + + response = client.get( + "/api/workflows", + headers=auth_headers(token, "missing-session"), + ) + + assert response.status_code == 400 + assert response.json()["detail"] == "Invalid session ID" + + def test_requests_without_session_header_are_rejected(self, auth_client) -> None: + """Session-aware remote storage APIs should fail closed without X-Session-ID.""" + client, token = auth_client + + response = client.get("/api/workflows", headers=auth_headers(token)) + assert response.status_code == 400 + assert response.json()["detail"] == "Invalid session ID" + + +@pytest.fixture +async def live_session_server_url(tmp_path, unused_tcp_port: int) -> str: + """Start a real session server over HTTP and return its base URL.""" + from py_code_mode.execution.container.server import create_app + + config = SessionConfig( + artifacts_path=tmp_path / 
"server-artifacts", + workflows_path=tmp_path / "server-workflows", + auth_disabled=True, + ) + app = create_app(config) + server_config = uvicorn.Config(app, host="127.0.0.1", port=unused_tcp_port, log_level="warning") + server = uvicorn.Server(server_config) + + task = asyncio.create_task(server.serve()) + while not server.started: + await asyncio.sleep(0.1) + + try: + yield f"http://127.0.0.1:{unused_tcp_port}" + finally: + server.should_exit = True + await task + + +@pytest.fixture +async def live_expiring_session_server_url( + tmp_path, unused_tcp_port: int, monkeypatch: pytest.MonkeyPatch +) -> str: + """Start a real session server whose bound sessions expire quickly.""" + from py_code_mode.execution.container import server as server_module + + monkeypatch.setattr(server_module, "SESSION_EXPIRY", 0.1) + config = SessionConfig( + artifacts_path=tmp_path / "server-artifacts", + workflows_path=tmp_path / "server-workflows", + auth_disabled=True, + ) + app = server_module.create_app(config) + server_config = uvicorn.Config(app, host="127.0.0.1", port=unused_tcp_port, log_level="warning") + server = uvicorn.Server(server_config) + + task = asyncio.create_task(server.serve()) + while not server.started: + await asyncio.sleep(0.1) + + try: + yield f"http://127.0.0.1:{unused_tcp_port}" + finally: + server.should_exit = True + await task + + +@pytest.fixture +async def live_redis_session_server_url( + remote_workspace_redis_url: str, unused_tcp_port: int +) -> str: + """Start a real Redis-backed session server over HTTP and return its base URL.""" + from py_code_mode.execution.container.server import create_app + + config = SessionConfig( + redis_url=remote_workspace_redis_url, + storage_prefix=REMOTE_REDIS_STORAGE_PREFIX, + auth_disabled=True, + ) + app = create_app(config) + server_config = uvicorn.Config(app, host="127.0.0.1", port=unused_tcp_port, log_level="warning") + server = uvicorn.Server(server_config) + + task = asyncio.create_task(server.serve()) + while 
not server.started: + await asyncio.sleep(0.1) + + try: + yield f"http://127.0.0.1:{unused_tcp_port}" + finally: + server.should_exit = True + await task + + +@pytest.fixture +def remote_workspace_redis_container(request: pytest.FixtureRequest): + """Start a dedicated Redis container without triggering container image rebuild hooks.""" + pytest.importorskip("testcontainers.redis") + from testcontainers.redis import RedisContainer + + if not _docker_daemon_is_available(): + if os.environ.get("CI"): + pytest.fail("Docker daemon not available for remote Redis workspace journey tests") + pytest.skip("Docker daemon not available for remote Redis workspace journey tests") + + container = RedisContainer(image="redis:7-alpine") + container.start() + try: + yield container + finally: + if did_test_fail(request.node): + emit_testcontainer_logs( + container, + source=f"testcontainers.RedisContainer ({request.node.nodeid})", + ) + container.stop() + + +@pytest.fixture +def remote_workspace_redis_url(remote_workspace_redis_container) -> str: + """Return a Redis URL for the dedicated workspace journey container.""" + host = remote_workspace_redis_container.get_container_host_ip() + port = remote_workspace_redis_container.get_exposed_port(6379) + return f"redis://{host}:{port}" + + +def create_remote_redis_session( + redis_url: str, + remote_url: str, + workspace_id: str | None = None, +) -> Session: + """Create a developer-style Session using RedisStorage and a remote container executor.""" + storage = RedisStorage( + redis=redis.from_url(redis_url), + prefix=REMOTE_REDIS_STORAGE_PREFIX, + workspace_id=workspace_id, + ) + executor = ContainerExecutor( + ContainerConfig( + remote_url=remote_url, + timeout=30.0, + auth_disabled=True, + ) + ) + return Session(storage=storage, executor=executor) + + +class TestRemoteWorkspaceSessionE2E: + """End-to-end tests through Session() against a live session server.""" + + @pytest.mark.asyncio + async def 
test_sessions_in_same_workspace_share_remote_state( + self, tmp_path, live_session_server_url: str + ) -> None: + """Two Session() objects in the same workspace should share remote workflows/artifacts.""" + session_a = Session( + storage=FileStorage(tmp_path / "client-storage-a", workspace_id="client_a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + session_b = Session( + storage=FileStorage(tmp_path / "client-storage-b", workspace_id="client_a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 'shared_remote',", + " 'async def run() -> str:\\n return \"ok\"\\n',", + " 'Shared remote workflow',", + ")", + "artifacts.save('shared_remote.json', {'value': 1}, description='')", + ] + ) + ) + assert result.error is None + + result = await session_b.run( + "[workflows.get('shared_remote') is not None, " + "artifacts.load('shared_remote.json')['value']]" + ) + assert result.error is None + assert result.value == [True, 1] + + @pytest.mark.asyncio + async def test_sessions_in_different_workspaces_are_isolated( + self, tmp_path, live_session_server_url: str + ) -> None: + """Two Session() objects in different workspaces should not share remote state.""" + session_a = Session( + storage=FileStorage(tmp_path / "client-storage-a", workspace_id="client_a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + session_b = Session( + storage=FileStorage(tmp_path / "client-storage-b", workspace_id="client_b"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + + async with session_a, session_b: + result = await 
session_a.run( + "\n".join( + [ + "workflows.create(", + " 'isolated_remote',", + " 'async def run() -> str:\\n return \"ok\"\\n',", + " 'Workspace A workflow',", + ")", + "artifacts.save('isolated_remote.json', {'value': 1}, description='')", + ] + ) + ) + assert result.error is None + + workflow_result = await session_b.run("workflows.get('isolated_remote')") + assert workflow_result.error is None + assert workflow_result.value is None + + artifact_result = await session_b.run("artifacts.exists('isolated_remote.json')") + assert artifact_result.error is None + assert artifact_result.value is False + + @pytest.mark.asyncio + async def test_sessions_without_workspace_id_use_shared_legacy_namespace( + self, tmp_path, live_session_server_url: str + ) -> None: + """Omitting workspace_id should keep the legacy shared namespace behavior.""" + session_a = Session( + storage=FileStorage(tmp_path / "client-storage-a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + session_b = Session( + storage=FileStorage(tmp_path / "client-storage-b"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 'legacy_remote',", + " 'async def run() -> str:\\n return \"ok\"\\n',", + " 'Legacy workflow',", + ")", + ] + ) + ) + assert result.error is None + + result = await session_b.run("workflows.get('legacy_remote') is not None") + assert result.error is None + assert result.value is True + + @pytest.mark.asyncio + async def test_workflow_search_is_isolated_across_session_workspaces( + self, tmp_path, live_session_server_url: str + ) -> None: + """Agent-facing workflow search should respect workspace boundaries end-to-end.""" + session_a = Session( + storage=FileStorage(tmp_path / "client-storage-a", 
workspace_id="client_a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + session_b = Session( + storage=FileStorage(tmp_path / "client-storage-b", workspace_id="client_b"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 'summarize_remote_notes',", + " 'async def run() -> str:\\n return \"ok\"\\n',", + " 'Summarize remote notes',", + ")", + ] + ) + ) + assert result.error is None + + result = await session_b.run("workflows.search('summarize')") + assert result.error is None + assert result.value == [] + + @pytest.mark.asyncio + async def test_expired_remote_session_rebinds_to_same_workspace( + self, tmp_path, live_expiring_session_server_url: str + ) -> None: + """A remote Session() should transparently recover from server-side expiry.""" + session = Session( + storage=FileStorage(tmp_path / "client-storage-a", workspace_id="client_a"), + executor=ContainerExecutor( + ContainerConfig( + remote_url=live_expiring_session_server_url, + timeout=30.0, + auth_disabled=True, + ) + ), + ) + + async with session: + result = await session.run( + "\n".join( + [ + "workflows.create(", + " 'survives_rebind',", + " 'async def run() -> str:\\n return \"ok\"\\n',", + " 'Persisted in workspace storage',", + ")", + ] + ) + ) + assert result.error is None + original_session_id = session._executor._client.session_id + assert original_session_id is not None + + await asyncio.sleep(0.2) + + result = await session.run("workflows.get('survives_rebind') is not None") + assert result.error is None + assert result.value is True + assert session._executor._client.session_id is not None + assert session._executor._client.session_id != original_session_id + + +@pytest.mark.xdist_group("remote-redis") 
+class TestRemoteWorkspaceRedisConfig: + """Tests for remote Redis workspace configuration edge cases.""" + + def test_workspace_session_creation_rejects_scoped_redis_fallback( + self, + remote_workspace_redis_url: str, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Workspace-scoped sessions should fail closed without an unscoped Redis root.""" + try: + from fastapi.testclient import TestClient + except ImportError: + pytest.skip("FastAPI not installed") + + from py_code_mode.execution.container.server import create_app + + monkeypatch.setenv("REDIS_WORKFLOWS_PREFIX", "app:ws:client_a:workflows") + monkeypatch.setenv("REDIS_ARTIFACTS_PREFIX", "app:ws:client_a:artifacts") + monkeypatch.delenv("REDIS_TOOLS_PREFIX", raising=False) + monkeypatch.delenv("REDIS_DEPS_PREFIX", raising=False) + monkeypatch.delenv("STORAGE_PREFIX", raising=False) + + app = create_app( + SessionConfig( + redis_url=remote_workspace_redis_url, + auth_disabled=True, + ) + ) + + with TestClient(app) as client: + response = client.post("/sessions", json={"workspace_id": "client_b"}) + + assert response.status_code == 503 + assert "storage_prefix" in response.json()["detail"] + + +@pytest.mark.xdist_group("remote-redis") +class TestRemoteWorkspaceRedisUserJourney: + """Developer-facing remote Redis journeys for workspace scoping.""" + + @pytest.mark.asyncio + async def test_same_workspace_sessions_share_remote_redis_state( + self, remote_workspace_redis_url: str, live_redis_session_server_url: str + ) -> None: + """Two sessions in one workspace should share workflows and artifacts through Redis.""" + session_a = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) + session_b = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 
'shared_remote_redis',", + " 'async def run() -> str:\\n return \"shared through redis\"\\n',", + " 'Shared remote Redis workflow',", + ")", + "artifacts.save(", + " 'shared_remote_redis.json',", + " {'workspace': 'client_a'},", + " description='shared',", + ")", + ] + ) + ) + assert result.error is None + + workflow = await session_b.get_workflow("shared_remote_redis") + assert workflow is not None + + artifact = await session_b.load_artifact("shared_remote_redis.json") + assert artifact == {"workspace": "client_a"} + + result = await session_b.run( + "[workflows.shared_remote_redis(), " + "artifacts.load('shared_remote_redis.json')['workspace']]" + ) + assert result.error is None + assert result.value == ["shared through redis", "client_a"] + + @pytest.mark.asyncio + async def test_different_workspaces_are_isolated_in_remote_redis( + self, remote_workspace_redis_url: str, live_redis_session_server_url: str + ) -> None: + """Different workspace IDs should isolate Redis-backed remote state and search.""" + session_a = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) + session_b = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_b", + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 'campaign_metrics',", + " 'async def run() -> str:\\n return \"workspace a\"\\n',", + " 'Analyze campaign metrics and summarize ad performance',", + ")", + "artifacts.save(", + " 'campaign_metrics.json',", + " {'workspace': 'client_a'},", + " description='isolated',", + ")", + ] + ) + ) + assert result.error is None + + assert await session_b.get_workflow("campaign_metrics") is None + + result = await session_b.run( + "[workflows.get('campaign_metrics'), " + "artifacts.exists('campaign_metrics.json'), " + "workflows.search('campaign metrics ad performance')]" + ) + assert result.error 
is None + assert result.value == [None, False, []] + + @pytest.mark.asyncio + async def test_fresh_session_rejoins_same_workspace_in_remote_redis( + self, remote_workspace_redis_url: str, live_redis_session_server_url: str + ) -> None: + """A newly created session with the same workspace should rejoin shared Redis state.""" + async with create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) as writer: + result = await writer.run( + "\n".join( + [ + "workflows.create(", + " 'rejoin_remote_redis',", + " 'async def run() -> str:\\n return \"rejoined\"\\n',", + " 'Workflow for rejoin test',", + ")", + "artifacts.save(", + " 'rejoin_remote_redis.json',", + " {'status': 'persisted'},", + " description='persisted',", + ")", + ] + ) + ) + assert result.error is None + + async with create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) as reader: + workflow = await reader.get_workflow("rejoin_remote_redis") + assert workflow is not None + + artifact = await reader.load_artifact("rejoin_remote_redis.json") + assert artifact == {"status": "persisted"} + + result = await reader.run("workflows.rejoin_remote_redis()") + assert result.error is None + assert result.value == "rejoined" + + @pytest.mark.asyncio + async def test_scoped_and_legacy_remote_redis_namespaces_are_isolated( + self, remote_workspace_redis_url: str, live_redis_session_server_url: str + ) -> None: + """Scoped Redis workspaces should be isolated from the legacy default namespace.""" + legacy_session = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + ) + scoped_session = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + workspace_id="client_a", + ) + + async with legacy_session, scoped_session: + result = await legacy_session.run( + "\n".join( + [ + "workflows.create(", + " 
'legacy_remote_redis',", + " 'async def run() -> str:\\n return \"legacy\"\\n',", + " 'Legacy default workflow',", + ")", + "artifacts.save(", + " 'legacy_remote_redis.json',", + " {'namespace': 'legacy'},", + " description='legacy',", + ")", + ] + ) + ) + assert result.error is None + + result = await scoped_session.run( + "[workflows.get('legacy_remote_redis'), " + "artifacts.exists('legacy_remote_redis.json')]" + ) + assert result.error is None + assert result.value == [None, False] + + result = await scoped_session.run( + "\n".join( + [ + "workflows.create(", + " 'scoped_remote_redis',", + " 'async def run() -> str:\\n return \"scoped\"\\n',", + " 'Scoped workflow',", + ")", + "artifacts.save(", + " 'scoped_remote_redis.json',", + " {'namespace': 'scoped'},", + " description='scoped',", + ")", + ] + ) + ) + assert result.error is None + + result = await legacy_session.run( + "[workflows.get('scoped_remote_redis'), " + "artifacts.exists('scoped_remote_redis.json')]" + ) + assert result.error is None + assert result.value == [None, False] + + @pytest.mark.asyncio + async def test_sessions_without_workspace_id_share_legacy_remote_redis_namespace( + self, remote_workspace_redis_url: str, live_redis_session_server_url: str + ) -> None: + """Two unscoped sessions should share the legacy default Redis namespace.""" + session_a = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + ) + session_b = create_remote_redis_session( + remote_workspace_redis_url, + live_redis_session_server_url, + ) + + async with session_a, session_b: + result = await session_a.run( + "\n".join( + [ + "workflows.create(", + " 'legacy_shared_remote_redis',", + " 'async def run() -> str:\\n return \"legacy shared\"\\n',", + " 'Legacy shared workflow',", + ")", + "artifacts.save(", + " 'legacy_shared_remote_redis.json',", + " {'namespace': 'legacy'},", + " description='legacy',", + ")", + ] + ) + ) + assert result.error is None + + result = await 
session_b.run( + "[workflows.legacy_shared_remote_redis(), " + "artifacts.load('legacy_shared_remote_redis.json')['namespace']]" + ) + assert result.error is None + assert result.value == ["legacy shared", "legacy"] diff --git a/tests/container/test_server.py b/tests/container/test_server.py index 1c7c929..58bab1d 100644 --- a/tests/container/test_server.py +++ b/tests/container/test_server.py @@ -1,10 +1,26 @@ """Tests for session server.""" +import asyncio +from unittest.mock import MagicMock + import pytest from py_code_mode.execution.container.config import SessionConfig +def create_bound_session(client, workspace_id: str | None = None) -> str: + """Create a session and return its bound session ID.""" + payload = {} if workspace_id is None else {"workspace_id": workspace_id} + response = client.post("/sessions", json=payload) + assert response.status_code == 200 + return response.json()["session_id"] + + +def session_headers(session_id: str) -> dict[str, str]: + """Build headers for a bound session request.""" + return {"X-Session-ID": session_id} + + class TestSessionConfig: """Tests for SessionConfig loading.""" @@ -81,7 +97,8 @@ def test_health_endpoint(self, client) -> None: def test_info_endpoint(self, client) -> None: """Info endpoint returns tools and workflows.""" - response = client.get("/info") + session_id = create_bound_session(client) + response = client.get("/info", headers=session_headers(session_id)) assert response.status_code == 200 data = response.json() @@ -91,7 +108,12 @@ def test_info_endpoint(self, client) -> None: def test_execute_simple_expression(self, client) -> None: """Can execute simple expression.""" - response = client.post("/execute", json={"code": "1 + 1"}) + session_id = create_bound_session(client) + response = client.post( + "/execute", + json={"code": "1 + 1"}, + headers=session_headers(session_id), + ) assert response.status_code == 200 data = response.json() @@ -101,7 +123,12 @@ def test_execute_simple_expression(self, 
client) -> None: def test_execute_with_stdout(self, client) -> None: """Captures stdout from print statements.""" - response = client.post("/execute", json={"code": "print('hello')"}) + session_id = create_bound_session(client) + response = client.post( + "/execute", + json={"code": "print('hello')"}, + headers=session_headers(session_id), + ) assert response.status_code == 200 data = response.json() @@ -110,7 +137,12 @@ def test_execute_with_stdout(self, client) -> None: def test_execute_with_error(self, client) -> None: """Returns error for invalid code.""" - response = client.post("/execute", json={"code": "1/0"}) + session_id = create_bound_session(client) + response = client.post( + "/execute", + json={"code": "1/0"}, + headers=session_headers(session_id), + ) assert response.status_code == 200 data = response.json() @@ -119,9 +151,11 @@ def test_execute_with_error(self, client) -> None: def test_execute_state_persists(self, client) -> None: """Variables persist across executions within same session.""" - create_response = client.post("/execute", json={"code": "x = 42"}) - session_id = create_response.json()["session_id"] - headers = {"X-Session-ID": session_id} + session_id = create_bound_session(client) + headers = session_headers(session_id) + + create_response = client.post("/execute", json={"code": "x = 42"}, headers=headers) + assert create_response.status_code == 200 # Access variable (same session) response = client.post("/execute", json={"code": "x * 2"}, headers=headers) @@ -133,9 +167,11 @@ def test_execute_state_persists(self, client) -> None: def test_reset_clears_state(self, client) -> None: """Reset clears session state.""" - create_response = client.post("/execute", json={"code": "x = 42"}) - session_id = create_response.json()["session_id"] - headers = {"X-Session-ID": session_id} + session_id = create_bound_session(client) + headers = session_headers(session_id) + + create_response = client.post("/execute", json={"code": "x = 42"}, 
headers=headers) + assert create_response.status_code == 200 # Reset this session response = client.post("/reset", headers=headers) @@ -156,6 +192,24 @@ def test_execute_with_unknown_session_id_returns_400(self, client) -> None: assert response.status_code == 400 assert response.json()["detail"] == "Invalid session ID" + def test_session_aware_endpoints_require_bound_session_id(self, client) -> None: + """Remote session-aware endpoints should fail closed without X-Session-ID.""" + execute_response = client.post("/execute", json={"code": "1 + 1"}) + assert execute_response.status_code == 400 + assert execute_response.json()["detail"] == "Invalid session ID" + + info_response = client.get("/info") + assert info_response.status_code == 400 + assert info_response.json()["detail"] == "Invalid session ID" + + workflows_response = client.get("/api/workflows") + assert workflows_response.status_code == 400 + assert workflows_response.json()["detail"] == "Invalid session ID" + + artifacts_response = client.get("/api/artifacts") + assert artifacts_response.status_code == 400 + assert artifacts_response.json()["detail"] == "Invalid session ID" + def test_reset_requires_known_session_id(self, client) -> None: """Reset rejects missing or unknown session IDs.""" missing_response = client.post("/reset") @@ -168,7 +222,12 @@ def test_reset_requires_known_session_id(self, client) -> None: def test_execute_returns_execution_time(self, client) -> None: """Execute response includes execution time.""" - response = client.post("/execute", json={"code": "1 + 1"}) + session_id = create_bound_session(client) + response = client.post( + "/execute", + json={"code": "1 + 1"}, + headers=session_headers(session_id), + ) data = response.json() assert "execution_time_ms" in data @@ -226,7 +285,8 @@ def client_with_tools(self, tmp_path, monkeypatch): def test_info_endpoint_includes_tools(self, client_with_tools) -> None: """Info endpoint lists tools loaded from TOOLS_PATH.""" - response = 
client_with_tools.get("/info") + session_id = create_bound_session(client_with_tools) + response = client_with_tools.get("/info", headers=session_headers(session_id)) assert response.status_code == 200 data = response.json() @@ -252,3 +312,83 @@ def test_session_creates_artifact_store(self, config_with_artifacts) -> None: # Session has a file artifact store pointing to artifacts_path assert isinstance(session.artifact_store, FileArtifactStore) assert "artifacts" in str(session.artifact_store._path) + + +class TestRedisDepsFallback: + """Tests for Redis deps initialization in the container server.""" + + def test_redis_deps_fallback_stays_unscoped_when_workflows_are_scoped( + self, monkeypatch, mock_redis + ) -> None: + """Deps fallback should use the root prefix, not the workspace-scoped workflows prefix.""" + from py_code_mode.execution.container import server as server_module + from py_code_mode.tools import ToolRegistry + + async def fake_registry_from_redis(_tool_store): + return ToolRegistry() + + monkeypatch.setenv("REDIS_URL", "redis://localhost:6379/0") + monkeypatch.setenv("REDIS_TOOLS_PREFIX", "app:tools") + monkeypatch.setenv("REDIS_WORKFLOWS_PREFIX", "app:ws:client_a:workflows") + monkeypatch.setenv("REDIS_ARTIFACTS_PREFIX", "app:ws:client_a:artifacts") + monkeypatch.delenv("REDIS_DEPS_PREFIX", raising=False) + + config = SessionConfig(auth_disabled=True) + + monkeypatch.setattr("redis.from_url", lambda _url: mock_redis) + monkeypatch.setattr( + "py_code_mode.storage.registry_from_redis", + fake_registry_from_redis, + ) + monkeypatch.setattr( + server_module, + "create_workflow_library", + lambda *, store: MagicMock( + refresh=lambda: None, list=lambda: [], search=lambda *_a, **_k: [] + ), + ) + + asyncio.run(server_module.initialize_server(config)) + + assert server_module._state.deps_store is not None + server_module._state.deps_store.add("requests") + + assert "requests" in mock_redis.smembers("app:deps") + assert 
mock_redis.smembers("app:ws:client_a:deps") == set() + + def test_explicit_redis_deps_prefix_takes_precedence(self, monkeypatch, mock_redis) -> None: + """Explicit REDIS_DEPS_PREFIX should override any fallback derivation.""" + from py_code_mode.execution.container import server as server_module + from py_code_mode.tools import ToolRegistry + + async def fake_registry_from_redis(_tool_store): + return ToolRegistry() + + monkeypatch.setenv("REDIS_URL", "redis://localhost:6379/0") + monkeypatch.setenv("REDIS_TOOLS_PREFIX", "app:tools") + monkeypatch.setenv("REDIS_WORKFLOWS_PREFIX", "app:ws:client_a:workflows") + monkeypatch.setenv("REDIS_ARTIFACTS_PREFIX", "app:ws:client_a:artifacts") + monkeypatch.setenv("REDIS_DEPS_PREFIX", "custom-root") + + config = SessionConfig(auth_disabled=True) + + monkeypatch.setattr("redis.from_url", lambda _url: mock_redis) + monkeypatch.setattr( + "py_code_mode.storage.registry_from_redis", + fake_registry_from_redis, + ) + monkeypatch.setattr( + server_module, + "create_workflow_library", + lambda *, store: MagicMock( + refresh=lambda: None, list=lambda: [], search=lambda *_a, **_k: [] + ), + ) + + asyncio.run(server_module.initialize_server(config)) + + assert server_module._state.deps_store is not None + server_module._state.deps_store.add("requests") + + assert "requests" in mock_redis.smembers("custom-root:deps") + assert mock_redis.smembers("app:deps") == set() diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 4440f19..fb8ae40 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -4,9 +4,6 @@ 1. `to_bootstrap_config()` method on storage classes - serializes to dict 2. `bootstrap_namespaces(config)` function - reconstructs storage from config 3. Lazy connections - storage only connects when actually used - -TDD RED phase: These tests are written before implementation. -They will fail until the bootstrap module is implemented. 
""" from __future__ import annotations @@ -713,6 +710,67 @@ async def test_config_roundtrip(self, tmp_path: Path) -> None: assert workflow is not None assert workflow.name == "greet" + @pytest.mark.asyncio + async def test_config_roundtrip_preserves_workspace_scoping(self, tmp_path: Path) -> None: + """Scoped file storage roundtrip preserves scoped visibility.""" + from py_code_mode.bootstrap import bootstrap_namespaces + from py_code_mode.errors import ArtifactNotFoundError + from py_code_mode.storage import FileStorage + from py_code_mode.workflows import PythonWorkflow + + scoped_storage = FileStorage(tmp_path, workspace_id="client_a") + legacy_storage = FileStorage(tmp_path) + + scoped_storage.get_artifact_store().save( + "scoped.json", + {"scope": "client_a"}, + description="Scoped artifact", + ) + legacy_storage.get_artifact_store().save( + "legacy.json", + {"scope": "legacy"}, + description="Legacy artifact", + ) + + scoped_storage.get_workflow_store().save( + PythonWorkflow.from_source( + name="scoped_workflow", + source='async def run() -> str:\n return "scoped"', + description="Workflow visible only in client_a workspace", + ) + ) + legacy_storage.get_workflow_store().save( + PythonWorkflow.from_source( + name="legacy_workflow", + source='async def run() -> str:\n return "legacy"', + description="Workflow visible only in unscoped storage", + ) + ) + + config = scoped_storage.to_bootstrap_config() + bundle = await bootstrap_namespaces(config) + + assert bundle.artifacts.load("scoped.json") == {"scope": "client_a"} + with pytest.raises(ArtifactNotFoundError): + bundle.artifacts.load("legacy.json") + + assert bundle.workflows.library.get("scoped_workflow") is not None + assert bundle.workflows.library.get("legacy_workflow") is None + + @pytest.mark.asyncio + async def test_config_roundtrip_keeps_file_deps_unscoped(self, tmp_path: Path) -> None: + """Scoped file storage bootstrap keeps deps rooted at the unscoped base path.""" + from py_code_mode.bootstrap 
import bootstrap_namespaces + from py_code_mode.storage import FileStorage + + scoped_storage = FileStorage(tmp_path, workspace_id="client_a") + + bundle = await bootstrap_namespaces(scoped_storage.to_bootstrap_config()) + bundle.deps._store.add("requests>=2.0") + + assert (tmp_path / "deps" / "requirements.txt").read_text() == "requests>=2.0\n" + assert not (tmp_path / "workspaces" / "client_a" / "deps").exists() + # ============================================================================= # RedisStorage.to_bootstrap_config() Tests @@ -873,6 +931,75 @@ async def test_config_roundtrip(self, mock_redis: MockRedisClient) -> None: assert workflow is not None assert workflow.name == "greet" + @pytest.mark.asyncio + async def test_config_roundtrip_preserves_workspace_scoping( + self, mock_redis: MockRedisClient + ) -> None: + """Scoped Redis storage roundtrip preserves scoped visibility.""" + from py_code_mode.bootstrap import bootstrap_namespaces + from py_code_mode.errors import ArtifactNotFoundError + from py_code_mode.storage import RedisStorage + from py_code_mode.workflows import PythonWorkflow + + scoped_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a") + legacy_storage = RedisStorage(redis=mock_redis, prefix="app") + + scoped_storage.get_artifact_store().save( + "scoped.json", + {"scope": "client_a"}, + description="Scoped artifact", + ) + legacy_storage.get_artifact_store().save( + "legacy.json", + {"scope": "legacy"}, + description="Legacy artifact", + ) + + scoped_storage.get_workflow_store().save( + PythonWorkflow.from_source( + name="scoped_workflow", + source='async def run() -> str:\n return "scoped"', + description="Workflow visible only in client_a workspace", + ) + ) + legacy_storage.get_workflow_store().save( + PythonWorkflow.from_source( + name="legacy_workflow", + source='async def run() -> str:\n return "legacy"', + description="Workflow visible only in unscoped storage", + ) + ) + + config = 
scoped_storage.to_bootstrap_config() + + with patch("redis.Redis.from_url", return_value=mock_redis): + bundle = await bootstrap_namespaces(config) + + assert bundle.artifacts.load("scoped.json") == {"scope": "client_a"} + with pytest.raises(ArtifactNotFoundError): + bundle.artifacts.load("legacy.json") + + assert bundle.workflows.library.get("scoped_workflow") is not None + assert bundle.workflows.library.get("legacy_workflow") is None + + @pytest.mark.asyncio + async def test_config_roundtrip_keeps_redis_deps_unscoped( + self, mock_redis: MockRedisClient + ) -> None: + """Scoped Redis storage bootstrap keeps deps under the root Redis prefix.""" + from py_code_mode.bootstrap import bootstrap_namespaces + from py_code_mode.storage import RedisStorage + + scoped_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a") + + with patch("redis.Redis.from_url", return_value=mock_redis): + bundle = await bootstrap_namespaces(scoped_storage.to_bootstrap_config()) + + bundle.deps._store.add("requests>=2.0") + + assert "requests>=2.0" in mock_redis.smembers("app:deps") + assert mock_redis.smembers("app:ws:client_a:deps") == set() + # ============================================================================= # Lazy Connection Tests diff --git a/tests/test_workspace_scoped_storage.py b/tests/test_workspace_scoped_storage.py new file mode 100644 index 0000000..5edc731 --- /dev/null +++ b/tests/test_workspace_scoped_storage.py @@ -0,0 +1,530 @@ +"""Tests for workspace-scoped storage behavior.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest + +from py_code_mode import Session +from py_code_mode.errors import ArtifactNotFoundError +from py_code_mode.storage import FileStorage, RedisStorage + +if TYPE_CHECKING: + from tests.conftest import MockRedisClient + + +SHARED_WORKFLOW_SOURCE = """async def run() -> str: + return "shared workflow" +""" + +SEARCHABLE_WORKFLOW_SOURCE = """async 
def run() -> str: + return "searchable workflow" +""" + + +class TestWorkspaceScopedFileStorageArtifacts: + """Session-facing artifact behavior with workspace-scoped FileStorage.""" + + @pytest.mark.asyncio + async def test_artifacts_are_isolated_between_workspaces(self, tmp_path: Path) -> None: + workspace_a = FileStorage(tmp_path, workspace_id="client_a") + workspace_b = FileStorage(tmp_path, workspace_id="client_b") + + async with Session(storage=workspace_a) as session_a: + await session_a.save_artifact( + name="campaign.json", + data={"workspace": "client_a"}, + description="Artifact scoped to client_a", + ) + + async with Session(storage=workspace_b) as session_b: + assert await session_b.list_artifacts() == [] + with pytest.raises(ArtifactNotFoundError): + await session_b.load_artifact("campaign.json") + + @pytest.mark.asyncio + async def test_artifacts_are_shared_by_separately_initialized_sessions_in_same_workspace( + self, tmp_path: Path + ) -> None: + writer_storage = FileStorage(tmp_path, workspace_id="shared_client") + reader_storage = FileStorage(tmp_path, workspace_id="shared_client") + + async with Session(storage=writer_storage) as writer: + await writer.save_artifact( + name="campaign.json", + data={"shared": True}, + description="Shared artifact", + ) + + async with Session(storage=reader_storage) as reader: + assert await reader.load_artifact("campaign.json") == {"shared": True} + artifacts = await reader.list_artifacts() + assert [artifact["name"] for artifact in artifacts] == ["campaign.json"] + + +class TestWorkspaceScopedFileStorageWorkflows: + """Session-facing workflow behavior with workspace-scoped FileStorage.""" + + @pytest.mark.asyncio + async def test_workflows_are_isolated_between_workspaces(self, tmp_path: Path) -> None: + workspace_a = FileStorage(tmp_path, workspace_id="client_a") + workspace_b = FileStorage(tmp_path, workspace_id="client_b") + + async with Session(storage=workspace_a) as session_a: + await session_a.add_workflow( 
+ name="shared_campaign", + source=SHARED_WORKFLOW_SOURCE, + description="Workflow scoped to client_a", + ) + + async with Session(storage=workspace_b) as session_b: + assert await session_b.get_workflow("shared_campaign") is None + assert await session_b.list_workflows() == [] + + @pytest.mark.asyncio + async def test_workflows_are_shared_by_separately_initialized_sessions_in_same_workspace( + self, tmp_path: Path + ) -> None: + writer_storage = FileStorage(tmp_path, workspace_id="shared_client") + reader_storage = FileStorage(tmp_path, workspace_id="shared_client") + + async with Session(storage=writer_storage) as writer: + await writer.add_workflow( + name="shared_campaign", + source=SHARED_WORKFLOW_SOURCE, + description="Workflow visible within one workspace", + ) + + async with Session(storage=reader_storage) as reader: + workflow = await reader.get_workflow("shared_campaign") + assert workflow is not None + assert workflow["name"] == "shared_campaign" + + result = await reader.run("workflows.shared_campaign()") + assert result.is_ok + assert result.value == "shared workflow" + + @pytest.mark.asyncio + async def test_workflow_search_is_isolated_between_workspaces(self, tmp_path: Path) -> None: + workspace_a = FileStorage(tmp_path, workspace_id="client_a") + workspace_b = FileStorage(tmp_path, workspace_id="client_b") + + async with Session(storage=workspace_a) as session_a: + await session_a.add_workflow( + name="campaign_search", + source=SEARCHABLE_WORKFLOW_SOURCE, + description="Analyze campaign metrics and summarize ad performance", + ) + results = await session_a.search_workflows("campaign metrics ad performance") + assert [workflow["name"] for workflow in results] == ["campaign_search"] + + async with Session(storage=workspace_b) as session_b: + results = await session_b.search_workflows("campaign metrics ad performance") + assert results == [] + + +class TestWorkspaceScopedRedisStorageArtifacts: + """Session-facing artifact behavior with 
workspace-scoped RedisStorage."""

    @pytest.mark.asyncio
    async def test_artifacts_are_isolated_between_workspaces(
        self, mock_redis: MockRedisClient
    ) -> None:
        """An artifact saved under one workspace_id is invisible from another workspace."""
        workspace_a = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")
        workspace_b = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_b")

        async with Session(storage=workspace_a) as session_a:
            await session_a.save_artifact(
                name="campaign.json",
                data={"workspace": "client_a"},
                description="Artifact scoped to client_a",
            )

        # Same Redis client and prefix, different workspace_id: nothing leaks across.
        async with Session(storage=workspace_b) as session_b:
            assert await session_b.list_artifacts() == []
            with pytest.raises(ArtifactNotFoundError):
                await session_b.load_artifact("campaign.json")

    @pytest.mark.asyncio
    async def test_artifacts_are_shared_by_separately_initialized_sessions_in_same_workspace(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Two independently constructed storages with the same workspace_id share artifacts."""
        writer_storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id="shared_client",
        )
        reader_storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id="shared_client",
        )

        async with Session(storage=writer_storage) as writer:
            await writer.save_artifact(
                name="campaign.json",
                data={"shared": True},
                description="Shared artifact",
            )

        async with Session(storage=reader_storage) as reader:
            assert await reader.load_artifact("campaign.json") == {"shared": True}
            artifacts = await reader.list_artifacts()
            assert [artifact["name"] for artifact in artifacts] == ["campaign.json"]


class TestWorkspaceScopedRedisStorageWorkflows:
    """Session-facing workflow behavior with workspace-scoped RedisStorage."""

    @pytest.mark.asyncio
    async def test_workflows_are_isolated_between_workspaces(
        self, mock_redis: MockRedisClient
    ) -> None:
        """A workflow added under one workspace_id is not listed or retrievable from another."""
        workspace_a = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")
        workspace_b = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_b")

        async with Session(storage=workspace_a) as session_a:
            await session_a.add_workflow(
                name="shared_campaign",
                source=SHARED_WORKFLOW_SOURCE,
                description="Workflow scoped to client_a",
            )

        async with Session(storage=workspace_b) as session_b:
            assert await session_b.get_workflow("shared_campaign") is None
            assert await session_b.list_workflows() == []

    @pytest.mark.asyncio
    async def test_workflows_are_shared_by_separately_initialized_sessions_in_same_workspace(
        self, mock_redis: MockRedisClient
    ) -> None:
        """A workflow written in one session is retrievable and runnable from a second
        session that targets the same workspace_id."""
        writer_storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id="shared_client",
        )
        reader_storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id="shared_client",
        )

        async with Session(storage=writer_storage) as writer:
            await writer.add_workflow(
                name="shared_campaign",
                source=SHARED_WORKFLOW_SOURCE,
                description="Workflow visible within one workspace",
            )

        async with Session(storage=reader_storage) as reader:
            workflow = await reader.get_workflow("shared_campaign")
            assert workflow is not None
            assert workflow["name"] == "shared_campaign"

            # The reader session can also execute the shared workflow end to end.
            result = await reader.run("workflows.shared_campaign()")
            assert result.is_ok
            assert result.value == "shared workflow"

    @pytest.mark.asyncio
    async def test_workflow_search_is_isolated_between_workspaces(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Workflow search only surfaces results from the caller's own workspace."""
        workspace_a = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")
        workspace_b = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_b")

        async with Session(storage=workspace_a) as session_a:
            await session_a.add_workflow(
                name="campaign_search",
                source=SEARCHABLE_WORKFLOW_SOURCE,
                description="Analyze campaign metrics and summarize ad performance",
            )
            results = await session_a.search_workflows("campaign metrics ad performance")
            assert [workflow["name"] for workflow in results] == ["campaign_search"]

        # Identical query from the sibling workspace must come back empty.
        async with Session(storage=workspace_b) as session_b:
            results = await session_b.search_workflows("campaign metrics ad performance")
            assert results == []


class TestWorkspaceScopedStorageDefaults:
    """Expected behavior when workspace_id is omitted."""

    @pytest.mark.asyncio
    async def test_omitting_workspace_id_preserves_current_unscoped_session_behavior(
        self, tmp_path: Path
    ) -> None:
        """Omitted workspace_id and an explicit workspace_id=None share the same
        legacy (unscoped) FileStorage namespace."""
        writer_storage = FileStorage(tmp_path)
        reader_storage = FileStorage(tmp_path, workspace_id=None)

        async with Session(storage=writer_storage) as writer:
            await writer.save_artifact(
                name="legacy.json",
                data={"mode": "legacy"},
                description="Legacy unscoped artifact",
            )

        async with Session(storage=reader_storage) as reader:
            assert await reader.load_artifact("legacy.json") == {"mode": "legacy"}

    @pytest.mark.asyncio
    async def test_omitting_workspace_id_preserves_current_unscoped_redis_behavior(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Same legacy-namespace equivalence as above, for RedisStorage."""
        writer_storage = RedisStorage(redis=mock_redis, prefix="app")
        reader_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id=None)

        async with Session(storage=writer_storage) as writer:
            await writer.save_artifact(
                name="legacy.json",
                data={"mode": "legacy"},
                description="Legacy unscoped artifact",
            )

        async with Session(storage=reader_storage) as reader:
            assert await reader.load_artifact("legacy.json") == {"mode": "legacy"}

    @pytest.mark.asyncio
    async def test_omitting_workspace_id_preserves_current_unscoped_workflow_behavior(
        self, tmp_path: Path
    ) -> None:
        """Workflows written without a workspace_id are visible and runnable by a
        reader that passes workspace_id=None (FileStorage)."""
        writer_storage = FileStorage(tmp_path)
        reader_storage = FileStorage(tmp_path, workspace_id=None)

        async with Session(storage=writer_storage) as writer:
            await writer.add_workflow(
                name="legacy_workflow",
                source=SHARED_WORKFLOW_SOURCE,
                description="Legacy unscoped workflow",
            )

        async with Session(storage=reader_storage) as reader:
            workflow = await reader.get_workflow("legacy_workflow")
            assert workflow is not None
            result = await reader.run("workflows.legacy_workflow()")
            assert result.is_ok
            assert result.value == "shared workflow"

    @pytest.mark.asyncio
    async def test_omitting_workspace_id_preserves_current_unscoped_redis_workflow_behavior(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Same legacy workflow visibility/runnability check, for RedisStorage."""
        writer_storage = RedisStorage(redis=mock_redis, prefix="app")
        reader_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id=None)

        async with Session(storage=writer_storage) as writer:
            await writer.add_workflow(
                name="legacy_workflow",
                source=SHARED_WORKFLOW_SOURCE,
                description="Legacy unscoped workflow",
            )

        async with Session(storage=reader_storage) as reader:
            workflow = await reader.get_workflow("legacy_workflow")
            assert workflow is not None
            result = await reader.run("workflows.legacy_workflow()")
            assert result.is_ok
            assert result.value == "shared workflow"

    @pytest.mark.asyncio
    async def test_scoped_file_storage_does_not_see_unscoped_artifacts(
        self, tmp_path: Path
    ) -> None:
        """Scoped FileStorage must not read artifacts from the legacy namespace."""
        legacy_storage = FileStorage(tmp_path)
        scoped_storage = FileStorage(tmp_path, workspace_id="client_a")

        async with Session(storage=legacy_storage) as legacy:
            await legacy.save_artifact(
                name="legacy.json",
                data={"mode": "legacy"},
                description="Legacy artifact",
            )

        async with Session(storage=scoped_storage) as scoped:
            assert await scoped.list_artifacts() == []
            with pytest.raises(ArtifactNotFoundError):
                await scoped.load_artifact("legacy.json")

    @pytest.mark.asyncio
    async def test_unscoped_file_storage_does_not_see_scoped_artifacts(
        self, tmp_path: Path
    ) -> None:
        """Isolation is symmetric: the legacy namespace cannot read scoped artifacts."""
        legacy_storage = FileStorage(tmp_path)
        scoped_storage = FileStorage(tmp_path, workspace_id="client_a")

        async with Session(storage=scoped_storage) as scoped:
            await scoped.save_artifact(
                name="workspace.json",
                data={"mode": "scoped"},
                description="Scoped artifact",
            )

        async with Session(storage=legacy_storage) as legacy:
            assert await legacy.list_artifacts() == []
            with pytest.raises(ArtifactNotFoundError):
                await legacy.load_artifact("workspace.json")

    @pytest.mark.asyncio
    async def test_scoped_file_storage_does_not_see_unscoped_workflows(
        self, tmp_path: Path
    ) -> None:
        """Scoped FileStorage must not retrieve or search legacy-namespace workflows."""
        legacy_storage = FileStorage(tmp_path)
        scoped_storage = FileStorage(tmp_path, workspace_id="client_a")

        async with Session(storage=legacy_storage) as legacy:
            await legacy.add_workflow(
                name="legacy_workflow",
                source=SHARED_WORKFLOW_SOURCE,
                description="Legacy unscoped workflow",
            )

        async with Session(storage=scoped_storage) as scoped:
            assert await scoped.get_workflow("legacy_workflow") is None
            assert await scoped.search_workflows("legacy unscoped workflow") == []

    @pytest.mark.asyncio
    async def test_scoped_redis_storage_does_not_see_unscoped_artifacts(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Scoped RedisStorage must not read artifacts from the legacy namespace."""
        legacy_storage = RedisStorage(redis=mock_redis, prefix="app")
        scoped_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")

        async with Session(storage=legacy_storage) as legacy:
            await legacy.save_artifact(
                name="legacy.json",
                data={"mode": "legacy"},
                description="Legacy artifact",
            )

        async with Session(storage=scoped_storage) as scoped:
            assert await scoped.list_artifacts() == []
            with pytest.raises(ArtifactNotFoundError):
                await scoped.load_artifact("legacy.json")

    @pytest.mark.asyncio
    async def test_unscoped_redis_storage_does_not_see_scoped_artifacts(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Isolation is symmetric for Redis: legacy namespace cannot read scoped artifacts."""
        legacy_storage = RedisStorage(redis=mock_redis, prefix="app")
        scoped_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")

        async with Session(storage=scoped_storage) as scoped:
            await scoped.save_artifact(
                name="workspace.json",
                data={"mode": "scoped"},
                description="Scoped artifact",
            )

        async with Session(storage=legacy_storage) as legacy:
            assert await legacy.list_artifacts() == []
            with pytest.raises(ArtifactNotFoundError):
                await legacy.load_artifact("workspace.json")

    @pytest.mark.asyncio
    async def test_scoped_redis_storage_does_not_see_unscoped_workflows(
        self, mock_redis: MockRedisClient
    ) -> None:
        """Scoped RedisStorage must not retrieve or search legacy-namespace workflows."""
        legacy_storage = RedisStorage(redis=mock_redis, prefix="app")
        scoped_storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")

        async with Session(storage=legacy_storage) as legacy:
            await legacy.add_workflow(
                name="legacy_workflow",
                source=SHARED_WORKFLOW_SOURCE,
                description="Legacy unscoped workflow",
            )

        async with Session(storage=scoped_storage) as scoped:
            assert await scoped.get_workflow("legacy_workflow") is None
            assert await scoped.search_workflows("legacy unscoped workflow") == []

    def test_file_storage_without_workspace_id_uses_legacy_layout(self, tmp_path: Path) -> None:
        """workspace_id=None keeps the flat <base>/{workflows,artifacts,vectors} layout."""
        storage = FileStorage(tmp_path, workspace_id=None)

        access = storage.get_serializable_access()

        assert access.workflows_path == tmp_path / "workflows"
        assert access.artifacts_path == tmp_path / "artifacts"
        # vectors_path may be absent on some backends; only assert when present.
        if access.vectors_path is not None:
            assert access.vectors_path == tmp_path / "vectors"

    def test_file_storage_workspace_id_scopes_paths(self, tmp_path: Path) -> None:
        """A workspace_id nests storage under <base>/workspaces/<id>/..."""
        storage = FileStorage(tmp_path, workspace_id="client_a")

        access = storage.get_serializable_access()

        assert access.workflows_path == tmp_path / "workspaces" / "client_a" / "workflows"
        assert access.artifacts_path == tmp_path / "workspaces" / "client_a" / "artifacts"
        if access.vectors_path is not None:
            assert access.vectors_path == tmp_path / "workspaces" / "client_a" / "vectors"

    def test_redis_storage_workspace_id_scopes_prefixes(self, mock_redis: MockRedisClient) -> None:
        """A workspace_id inserts a ws:<id> segment into every Redis key prefix."""
        storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id="client_a",
        )

        access = storage.get_serializable_access()

        assert access.workflows_prefix == "app:ws:client_a:workflows"
        assert access.artifacts_prefix == "app:ws:client_a:artifacts"
        if access.vectors_prefix is not None:
            assert access.vectors_prefix == "app:ws:client_a:vectors"

    def test_redis_storage_without_workspace_id_preserves_current_prefixes(
        self, mock_redis: MockRedisClient
    ) -> None:
        """workspace_id=None keeps the pre-existing <prefix>:<kind> key scheme."""
        storage = RedisStorage(
            redis=mock_redis,
            prefix="app",
            workspace_id=None,
        )

        access = storage.get_serializable_access()

        assert access.workflows_prefix == "app:workflows"
        assert access.artifacts_prefix == "app:artifacts"
        if access.vectors_prefix is not None:
            assert access.vectors_prefix == "app:vectors"


class TestWorkspaceScopedBootstrapConfig:
    """Bootstrap config preserves workspace scope without re-scoping deps roots."""

    def test_file_storage_workspace_id_is_serialized_explicitly(self, tmp_path: Path) -> None:
        """FileStorage bootstrap config carries workspace_id next to the raw base_path."""
        storage = FileStorage(tmp_path, workspace_id="client_a")

        config = storage.to_bootstrap_config()

        # base_path stays the unscoped root; scoping is carried by workspace_id.
        assert config["base_path"] == str(tmp_path)
        assert config["workspace_id"] == "client_a"

    def test_redis_storage_workspace_id_is_serialized_explicitly(
        self, mock_redis: MockRedisClient
    ) -> None:
        """RedisStorage bootstrap config carries workspace_id next to the raw prefix."""
        storage = RedisStorage(redis=mock_redis, prefix="app", workspace_id="client_a")

        config = storage.to_bootstrap_config()

        assert config["prefix"] == "app"
        assert config["workspace_id"] == "client_a"


class TestWorkspaceIdValidation:
    """Validation behavior for workspace identifiers."""

    # NOTE(review): r"bad\\name" is a raw string and therefore contains TWO literal
    # backslashes; confirm a single backslash ("bad\\name" or r"bad\name") wasn't intended.
    @pytest.mark.parametrize("workspace_id", ["", ".", "..", "../escape", "bad/name", r"bad\\name"])
    def test_file_storage_rejects_invalid_workspace_ids(
        self, tmp_path: Path, workspace_id: str
    ) -> None:
        """Path-traversal-ish and separator-bearing ids are rejected at construction."""
        with pytest.raises(ValueError):
            FileStorage(tmp_path, workspace_id=workspace_id)

    @pytest.mark.parametrize("workspace_id", ["", ".", "..", "../escape", "bad/name", "bad:name"])
    def test_redis_storage_rejects_invalid_workspace_ids(
        self, mock_redis: MockRedisClient, workspace_id: str
    ) -> None:
        """Ids that would corrupt the colon-delimited Redis key scheme are rejected."""
        with pytest.raises(ValueError):
            RedisStorage(redis=mock_redis, prefix="app", workspace_id=workspace_id)