From 2b23bbc3fdb08b2eb66f52f23bc95040d05d7ed6 Mon Sep 17 00:00:00 2001 From: Rauf Akdemir Date: Fri, 10 Apr 2026 17:24:56 +0200 Subject: [PATCH 01/25] feat(sharepoint): client credentials auth + ACL-filtered search (#1750) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(search): add ACL filtering to V2 search + fix SP Online access control - Add acl_principals param to VectorDB protocol and Vespa client - Executor resolves user principals via AccessBroker, builds ACL YQL clause - Admin search-as-user endpoints for all 3 tiers (instant/classic/agentic) - Fail-closed transformer: AC sources without access data default to invisible - SP Online: set access on site, drive, and page entities (drive root permissions) - SP Online: remove dead build_item_entity code - SP token exchange: implement OAuthTokenProvider.get_token_for_resource() - OAuth2Service.exchange_token_for_scope() for resource-scoped token exchange - Fix OAuth callback to pass stored config_fields to lifecycle.validate() - Fix pre-existing D301 ruff error in vespa transformer * feat(sharepoint): enforce required site_url + track Entra groups on all entity types - Make site_url required in SharePointOnlineConfig (one site per connection) - Fix validate_config to validate required fields even when config is empty - Add _track_entity_groups for site, drive, and page entities (enables Entra group expansion) - Update config SSRF tests for new required site_url behavior * feat(sharepoint): switch to application-level client credentials auth - Remove OAuth auth methods, use DIRECT only with client credentials - Add SharePointOnlineAppAuthConfig (tenant_id, client_id, client_secret, private_key) - Implement client_credentials token exchange for Graph API - Implement certificate-based JWT assertion for SP REST API tokens - Use getAllSites for complete site discovery (app permissions) - Add Prefer: deltashowsharingchanges headers for permission change tracking in delta - Make site_url optional (empty = sync all sites in tenant) - Simplify all token providers, group expander, and _get() to single auth path * fix(sharepoint): fix file downloads for client-credentials auth DirectCredentialProvider has no get_token(), so FileService raised "No access token available" for every file download. Two fixes: 1. Recognize SharePoint tempauth= URLs as pre-signed (skip auth header) 2. Fall back to a Graph bearer token via StaticTokenProvider when the download URL is a Graph API content endpoint Tested: 18/18 files now sync successfully with client credentials. Co-Authored-By: Claude Opus 4.6 * feat(sharepoint): split into OAuth + client credentials sources with shared base Refactors the single SharePointOnlineSource into two separate sources: - sharepoint_online (OAuth, delegated user auth) — original behavior - sharepoint_online_app (client credentials, app-only auth) — new Both inherit from SharePointOnlineBase which contains all shared sync, browse tree, file download, and ACL logic. Subclasses only override auth-specific hooks: token acquisition, 401 handling, SP REST token exchange, file download auth, and site discovery strategy. Also fixes SP site group expansion by using the actual uploaded certificate PEM for x5t thumbprint computation instead of generating a new self-signed cert each call (which caused thumbprint mismatch with Azure AD). Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Daan Manneke Co-authored-by: Claude Opus 4.6 --- backend/airweave/api/v1/endpoints/search.py | 96 ++- backend/airweave/core/container/factory.py | 3 + .../domains/oauth/callback_service.py | 3 + .../airweave/domains/oauth/oauth2_service.py | 76 +++ backend/airweave/domains/oauth/protocols.py | 18 + .../adapters/vector_db/fakes/vector_db.py | 1 + .../search/adapters/vector_db/protocol.py | 7 +- .../search/adapters/vector_db/vespa_client.py | 30 +- .../airweave/domains/search/agentic/agent.py | 13 +- .../domains/search/agentic/service.py | 5 +- .../domains/search/agentic/tools/search.py | 3 + .../domains/search/classic/service.py | 7 +- backend/airweave/domains/search/executor.py | 79 ++- .../airweave/domains/search/fakes/executor.py | 1 + .../domains/search/instant/service.py | 7 +- backend/airweave/domains/search/protocols.py | 6 +- .../domains/search/tests/test_executor.py | 4 + .../domains/sources/token_providers/oauth.py | 25 + .../airweave/domains/sources/validation.py | 10 +- .../airweave/domains/storage/file_service.py | 2 +- .../sync_pipeline/builders/destinations.py | 9 +- .../airweave/domains/sync_pipeline/factory.py | 4 + backend/airweave/platform/configs/auth.py | 46 ++ backend/airweave/platform/configs/config.py | 16 +- .../destinations/vespa/destination.py | 2 + .../destinations/vespa/transformer.py | 15 +- backend/airweave/platform/sources/__init__.py | 3 +- .../sources/sharepoint_online/__init__.py | 8 +- .../sources/sharepoint_online/builders.py | 45 +- .../sources/sharepoint_online/client.py | 37 +- .../sources/sharepoint_online/source.py | 638 ++++++++++++++---- .../unit/platform/configs/test_config_ssrf.py | 16 +- 32 files changed, 1019 insertions(+), 216 deletions(-) diff --git a/backend/airweave/api/v1/endpoints/search.py b/backend/airweave/api/v1/endpoints/search.py index 9a8aa5c19..a9c6b5c74 100644 --- a/backend/airweave/api/v1/endpoints/search.py +++ b/backend/airweave/api/v1/endpoints/search.py @@ -10,7 +10,7 @@ import json from collections.abc import AsyncGenerator -from fastapi import Depends, Path +from fastapi import Depends, Path, Query from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import StreamingResponse @@ -349,3 +349,97 @@ async def admin_stream_agentic_search( media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) + + +@admin_router.post( + "/{readable_id}/search/instant/as-user", + response_model=SearchV2Response, + summary="Instant Search as User (Admin)", + description=( + "Admin-only: Instant search with access control applied for a specific user principal." + ), +) +async def admin_instant_search_as_user( + readable_id: str = Path(...), + request: InstantSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. " + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: InstantSearchServiceProtocol = Inject(InstantSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: instant search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.QUERIES) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + return SearchV2Response(results=results.results) + + +@admin_router.post( + "/{readable_id}/search/classic/as-user", + response_model=SearchV2Response, + summary="Classic Search as User (Admin)", + description=( + "Admin-only: Classic search with access control applied for a specific user principal." + ), +) +async def admin_classic_search_as_user( + readable_id: str = Path(...), + request: ClassicSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. " + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: ClassicSearchServiceProtocol = Inject(ClassicSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: classic search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.QUERIES) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + return SearchV2Response(results=results.results) + + +@admin_router.post( + "/{readable_id}/search/agentic/as-user", + response_model=SearchV2Response, + summary="Agentic Search as User (Admin)", + description=( + "Admin-only: Agentic search with access control applied for a specific user principal." + ), +) +async def admin_agentic_search_as_user( + readable_id: str = Path(...), + request: AgenticSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. " + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: AgenticSearchServiceProtocol = Inject(AgenticSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: agentic search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.TOKENS) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + truncated = results.results[: request.limit] if request.limit else results.results + return SearchV2Response(results=truncated) diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 61565cc00..4d3d55ac4 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -496,6 +496,7 @@ def create_container(settings: Settings) -> Container: entity_definition_registry=source_deps["entity_definition_registry"], event_bus=event_bus, source_lifecycle=source_deps["source_lifecycle_service"], + access_broker=access_broker, ) # ----------------------------------------------------------------- @@ -1284,6 +1285,7 @@ def _create_search_services( entity_definition_registry: "EntityDefinitionRegistry", event_bus: "EventBus", source_lifecycle: "SourceLifecycleService", + access_broker: "AccessBroker", ) -> dict: """Create search domain services (LLM, tokenizer, reranker, metadata builder, per-tier). @@ -1351,6 +1353,7 @@ def _create_search_services( sc_repo=sc_repo, source_registry=source_registry, source_lifecycle=source_lifecycle, + access_broker=access_broker, ) # 6. Per-tier services diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py index 8b6ee0650..a27f7eb3b 100644 --- a/backend/airweave/domains/oauth/callback_service.py +++ b/backend/airweave/domains/oauth/callback_service.py @@ -203,6 +203,7 @@ async def complete_oauth2_callback( await self._validate_oauth2_token_or_raise( source_entry=source_entry, access_token=token_response.access_token, + config=source_conn_shell.config_fields, ctx=ctx, ) @@ -591,6 +592,7 @@ async def _validate_oauth2_token_or_raise( *, source_entry: SourceRegistryEntry | None, access_token: str, + config: dict | None = None, ctx: ApiContext, ) -> None: """Validate OAuth2 token using source lifecycle service; fail callback if invalid.""" @@ -601,6 +603,7 @@ async def _validate_oauth2_token_or_raise( await self._source_lifecycle.validate( short_name=source_entry.short_name, credentials=access_token, + config=config, ) except (SourceNotFoundError, SourceError) as e: raise http_exception_for_credential_validation( diff --git a/backend/airweave/domains/oauth/oauth2_service.py b/backend/airweave/domains/oauth/oauth2_service.py index e53fa4e26..b5c2188ef 100644 --- a/backend/airweave/domains/oauth/oauth2_service.py +++ b/backend/airweave/domains/oauth/oauth2_service.py @@ -475,6 +475,82 @@ async def refresh_and_persist( expires_in=response.expires_in, ) + async def exchange_token_for_scope( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + scope: str, + ) -> str: + """Exchange refresh token for an access token with a different scope. + + Uses the existing refresh token but requests a different resource scope + (e.g., SharePoint REST API scope instead of Graph scope). + Does NOT persist the rotated refresh token. + + Returns the access token string for the requested scope. + """ + connection = await self.conn_repo.get(db=db, id=connection_id, ctx=ctx) + if not connection or not connection.integration_credential_id: + raise OAuthRefreshCredentialMissingError( + f"Connection {connection_id} not found or has no credential", + integration_short_name=integration_short_name, + ) + + credential = await self.cred_repo.get( + db=db, id=connection.integration_credential_id, ctx=ctx + ) + if not credential: + raise OAuthRefreshCredentialMissingError( + "Integration credential not found", + integration_short_name=integration_short_name, + ) + + decrypted = self.encryptor.decrypt(credential.encrypted_credentials) + refresh_token = await self._get_refresh_token(ctx.logger, decrypted) + + integration_config = await self._get_integration_config(ctx.logger, integration_short_name) + + client_id, client_secret = await self._get_client_credentials( + integration_config, None, decrypted + ) + + # Build request with explicit scope (unlike normal refresh which skips scope) + headers = {"Content-Type": integration_config.content_type} + payload = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "scope": scope, + } + + if integration_config.client_credential_location == "header": + encoded = self._encode_client_credentials(client_id, client_secret) + headers["Authorization"] = f"Basic {encoded}" + else: + payload["client_id"] = client_id + payload["client_secret"] = client_secret + + ctx.logger.info( + f"Exchanging token for scope {scope} (integration={integration_short_name})" + ) + + response = await self._make_token_request( + ctx.logger, + integration_config.backend_url, + headers, + payload, + integration_short_name=integration_short_name, + ) + + # Parse response but do NOT persist the refresh token + token_response = OAuth2TokenResponse(**response.json()) + ctx.logger.info( + f"Successfully exchanged token for scope {scope} " + f"(expires_in={token_response.expires_in})" + ) + return str(token_response.access_token) + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/oauth/protocols.py b/backend/airweave/domains/oauth/protocols.py index ba55594e1..4faa4cf05 100644 --- a/backend/airweave/domains/oauth/protocols.py +++ b/backend/airweave/domains/oauth/protocols.py @@ -143,6 +143,24 @@ async def refresh_and_persist( """ ... + async def exchange_token_for_scope( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + scope: str, + ) -> str: + """Exchange refresh token for an access token with a different scope. + + Uses the existing refresh token but requests a different resource scope. + Does NOT persist the rotated refresh token (the response token is + scoped to the new resource and should not replace the original). + + Returns the access token string for the requested scope. + """ + ... + # --------------------------------------------------------------------------- # Init session + redirect session repositories diff --git a/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py b/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py index 1a4a4312e..5c3c63e41 100644 --- a/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py +++ b/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py @@ -61,6 +61,7 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: list[str] | None = None, ) -> CompiledQuery: """Return a fake compiled query, or raise seeded error.""" self._calls.append(("compile_query", plan, embeddings, collection_id)) diff --git a/backend/airweave/domains/search/adapters/vector_db/protocol.py b/backend/airweave/domains/search/adapters/vector_db/protocol.py index 2cefaabea..aa5f0dac2 100644 --- a/backend/airweave/domains/search/adapters/vector_db/protocol.py +++ b/backend/airweave/domains/search/adapters/vector_db/protocol.py @@ -1,6 +1,6 @@ """Vector database protocol for the search module.""" -from typing import Protocol +from typing import Optional, Protocol from airweave.domains.search.types.embeddings import QueryEmbeddings from airweave.domains.search.types.filters import FilterGroup @@ -23,6 +23,7 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> CompiledQuery: """Compile plan and embeddings into a DB-specific query. @@ -30,6 +31,10 @@ async def compile_query( plan: Search plan with queries, filters, strategy, pagination. embeddings: Dense and sparse embeddings for the queries. collection_id: Collection readable ID for tenant filtering. + acl_principals: Resolved user principals for access control filtering. + None = no AC sources in collection (skip filtering). + [] = user has no principals (only public entities visible). + ["user:x", "group:y"] = match these principals. Returns: CompiledQuery with raw (full) and display (no embeddings) versions. diff --git a/backend/airweave/domains/search/adapters/vector_db/vespa_client.py b/backend/airweave/domains/search/adapters/vector_db/vespa_client.py index 9a864fafd..9dfcb2704 100644 --- a/backend/airweave/domains/search/adapters/vector_db/vespa_client.py +++ b/backend/airweave/domains/search/adapters/vector_db/vespa_client.py @@ -87,9 +87,10 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> CompiledQuery: """Compile plan and embeddings into Vespa query.""" - yql = self._build_yql(plan, collection_id) + yql = self._build_yql(plan, collection_id, acl_principals=acl_principals) params = self._build_params(plan, embeddings) raw_query = {"yql": yql, "params": params} @@ -236,7 +237,12 @@ async def close(self) -> None: # YQL Building # ========================================================================= - def _build_yql(self, plan: SearchPlan, collection_id: str) -> str: + def _build_yql( + self, + plan: SearchPlan, + collection_id: str, + acl_principals: Optional[list[str]] = None, + ) -> str: """Build the complete YQL query string.""" num_embeddings = self._count_dense_embeddings(plan) retrieval_clause = self._build_retrieval_clause(plan.retrieval_strategy, num_embeddings) @@ -250,11 +256,31 @@ def _build_yql(self, plan: SearchPlan, collection_id: str) -> str: if filter_yql: where_parts.append(f"({filter_yql})") + acl_yql = self._build_acl_clause(acl_principals) + if acl_yql: + where_parts.append(f"({acl_yql})") + all_schemas = ", ".join(ALL_VESPA_SCHEMAS) yql = f"select * from sources {all_schemas} where {' AND '.join(where_parts)}" return yql + def _build_acl_clause(self, acl_principals: Optional[list[str]]) -> Optional[str]: + """Build access control YQL clause from resolved principals. + + Returns None if acl_principals is None (no AC sources → skip filtering). + Returns a clause that matches public entities or entities with matching viewers. + """ + if acl_principals is None: + return None + + clauses = ["access_is_public = true"] + for principal in acl_principals: + escaped = principal.replace("\\", "\\\\").replace("'", "\\'") + clauses.append(f"access_viewers contains '{escaped}'") + + return " OR ".join(clauses) + def _build_retrieval_clause( self, strategy: RetrievalStrategy, diff --git a/backend/airweave/domains/search/agentic/agent.py b/backend/airweave/domains/search/agentic/agent.py index 8097ebed1..75b0da96a 100644 --- a/backend/airweave/domains/search/agentic/agent.py +++ b/backend/airweave/domains/search/agentic/agent.py @@ -128,11 +128,13 @@ async def run( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Run the agent loop. Emits events throughout. Returns collected results.""" start_time = time.monotonic() state = AgentState() diag = _DiagnosticsAccumulator() + self._user_principal_override = user_principal_override ctx.logger.info( f"Agentic search started collection={readable_id} query={request.query!r} " f"thinking={request.thinking}" @@ -208,7 +210,14 @@ async def _run( # noqa: C901 — agent loop orchestration is inherently complex thinking_enabled = request.thinking # Construct per-request tools - dispatcher = self._build_dispatcher(collection_id, user_filter, db, ctx, readable_id) + dispatcher = self._build_dispatcher( + collection_id, + user_filter, + db, + ctx, + readable_id, + user_principal=self._user_principal_override, + ) context_mgr = ContextManager( tokenizer=self._tokenizer, @@ -587,6 +596,7 @@ def _build_dispatcher( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: str | None = None, ) -> ToolDispatcher: """Construct tools and dispatcher for this request.""" return ToolDispatcher( @@ -598,6 +608,7 @@ def _build_dispatcher( db=db, ctx=ctx, collection_readable_id=collection_readable_id, + user_principal=user_principal, ), ToolName.READ: ReadTool( vector_db=self._vector_db, diff --git a/backend/airweave/domains/search/agentic/service.py b/backend/airweave/domains/search/agentic/service.py index bd167cd36..7f15fd193 100644 --- a/backend/airweave/domains/search/agentic/service.py +++ b/backend/airweave/domains/search/agentic/service.py @@ -72,6 +72,7 @@ async def search( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Run agentic search and return results.""" agent = Agent( @@ -85,4 +86,6 @@ async def search( event_bus=self._event_bus, config=SearchConfig(), ) - return await agent.run(db, ctx, readable_id, request) + return await agent.run( + db, ctx, readable_id, request, user_principal_override=user_principal_override + ) diff --git a/backend/airweave/domains/search/agentic/tools/search.py b/backend/airweave/domains/search/agentic/tools/search.py index 3818d6311..20a6e53c1 100644 --- a/backend/airweave/domains/search/agentic/tools/search.py +++ b/backend/airweave/domains/search/agentic/tools/search.py @@ -43,6 +43,7 @@ def __init__( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: str | None = None, ) -> None: """Initialize with executor, user filter, collection ID, and request context.""" self._executor = executor @@ -51,6 +52,7 @@ def __init__( self._db = db self._ctx = ctx self._collection_readable_id = collection_readable_id + self._user_principal = user_principal async def execute( self, @@ -67,6 +69,7 @@ async def execute( db=self._db, ctx=self._ctx, collection_readable_id=self._collection_readable_id, + user_principal=self._user_principal, ) # Track new results in state diff --git a/backend/airweave/domains/search/classic/service.py b/backend/airweave/domains/search/classic/service.py index f8ed42a94..b2f98db5f 100644 --- a/backend/airweave/domains/search/classic/service.py +++ b/backend/airweave/domains/search/classic/service.py @@ -68,13 +68,16 @@ async def search( ctx: ApiContext, readable_id: str, request: ClassicSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Generate strategy via LLM, execute, optionally rerank, return results.""" start_time = time.monotonic() ctx.logger.info(f"Classic search started collection={readable_id} query={request.query!r}") try: - result = await self._execute(db, ctx, readable_id, request, start_time) + result = await self._execute( + db, ctx, readable_id, request, start_time, user_principal_override + ) duration_ms = int((time.monotonic() - start_time) * 1000) ctx.logger.info( f"Classic search completed collection={readable_id} " @@ -106,6 +109,7 @@ async def _execute( readable_id: str, request: ClassicSearchRequest, start_time: float, + user_principal_override: str | None = None, ) -> SearchResults: """Internal execution — resolve collection, LLM strategy, search, rerank.""" # 1. Resolve collection @@ -157,6 +161,7 @@ async def _execute( db=db, ctx=ctx, collection_readable_id=readable_id, + user_principal=user_principal_override, ) # 6. Optional rerank diff --git a/backend/airweave/domains/search/executor.py b/backend/airweave/domains/search/executor.py index 05c4c8ad4..8869e81f1 100644 --- a/backend/airweave/domains/search/executor.py +++ b/backend/airweave/domains/search/executor.py @@ -9,12 +9,13 @@ import asyncio from datetime import datetime -from typing import Any +from typing import Any, Optional from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from airweave.api.context import ApiContext +from airweave.domains.access_control.protocols import AccessBrokerProtocol from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.search.adapters.vector_db.protocol import VectorDBProtocol from airweave.domains.search.builders.search_plan import SearchPlanBuilder @@ -67,6 +68,7 @@ def __init__( sc_repo: SourceConnectionRepositoryProtocol, source_registry: SourceRegistryProtocol, source_lifecycle: SourceLifecycleServiceProtocol, + access_broker: AccessBrokerProtocol, ) -> None: """Initialize with embedders, vector database, and federated source dependencies.""" self._dense_embedder = dense_embedder @@ -75,6 +77,7 @@ def __init__( self._sc_repo = sc_repo self._source_registry = source_registry self._source_lifecycle = source_lifecycle + self._access_broker = access_broker async def execute( self, @@ -84,8 +87,14 @@ async def execute( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: Optional[str] = None, ) -> SearchResults: """Execute the full search pipeline including federated sources.""" + # 0. Resolve access control principals + acl_principals = await self._resolve_acl_principals( + db, ctx, user_principal, collection_readable_id + ) + # 1. Merge plan filters with user filters complete_plan = SearchPlanBuilder.build(plan, user_filter) @@ -106,7 +115,9 @@ async def execute( # 4. Run vector DB search and federated search in parallel fetch_limit = original_offset + original_limit - vector_task = asyncio.create_task(self._execute_vector_search(complete_plan, collection_id)) + vector_task = asyncio.create_task( + self._execute_vector_search(complete_plan, collection_id, acl_principals) + ) fed_task = None if federated_sources: @@ -130,6 +141,10 @@ async def execute( # Filter federated results in-memory and merge, or slice vector-only. fed_filtered = self._apply_filters_in_memory(fed_results, complete_plan.filter_groups) + # Also apply ACL filtering to federated results in-memory + if acl_principals is not None: + fed_filtered = self._apply_acl_in_memory(fed_filtered, acl_principals) + if fed_filtered: merged = self._merge_with_rrf(vector_results, fed_filtered) return SearchResults(results=merged[original_offset : original_offset + original_limit]) @@ -143,6 +158,7 @@ async def _execute_vector_search( self, plan: SearchPlan, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> list[SearchResult]: """Embed, compile, and execute vector DB search. @@ -174,9 +190,68 @@ async def _execute_vector_search( plan=plan, embeddings=embeddings, collection_id=collection_id, + acl_principals=acl_principals, ) return (await self._vector_db.execute_query(compiled_query)).results + # ------------------------------------------------------------------ + # Access control resolution + # ------------------------------------------------------------------ + + async def _resolve_acl_principals( + self, + db: AsyncSession, + ctx: ApiContext, + user_principal: Optional[str], + collection_readable_id: str, + ) -> Optional[list[str]]: + """Resolve user's ACL principals for a collection. + + Returns None if user_principal is not set or collection has no AC sources. + Returns a list of principals (possibly empty) otherwise. + """ + if not user_principal: + return None + + access_context = await self._access_broker.resolve_access_context_for_collection( + db=db, + user_principal=user_principal, + readable_collection_id=collection_readable_id, + organization_id=ctx.organization.id, + ) + + if access_context is None: + return None + + principals = list(access_context.all_principals) + ctx.logger.info(f"[ACL] Resolved {len(principals)} principals for user '{user_principal}'") + return principals + + @staticmethod + def _apply_acl_in_memory( + results: list[SearchResult], + principals: list[str], + ) -> list[SearchResult]: + """Apply ACL filtering to results in-memory (for federated sources). + + Keeps results that are: + - From non-AC sources (is_public is None — no ACL data means pass through) + - Explicitly public (is_public is True) + - Matching a viewer principal + """ + principal_set = set(principals) + + def _passes(r: SearchResult) -> bool: + if r.access.is_public is None: + return True # Non-AC source — no access data, pass through + if r.access.is_public: + return True + if r.access.viewers: + return bool(principal_set & set(r.access.viewers)) + return False + + return [r for r in results if _passes(r)] + # ------------------------------------------------------------------ # Federated source discovery # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/search/fakes/executor.py b/backend/airweave/domains/search/fakes/executor.py index 5a97fb8d0..c49226d85 100644 --- a/backend/airweave/domains/search/fakes/executor.py +++ b/backend/airweave/domains/search/fakes/executor.py @@ -40,6 +40,7 @@ async def execute( db: Any = None, ctx: Any = None, collection_readable_id: str = "", + user_principal: str | None = None, ) -> SearchResults: """Record the call and return seeded result, or raise seeded error.""" self._calls.append(("execute", plan, user_filter, collection_id)) diff --git a/backend/airweave/domains/search/instant/service.py b/backend/airweave/domains/search/instant/service.py index 75ad1e7b0..5e973f8e2 100644 --- a/backend/airweave/domains/search/instant/service.py +++ b/backend/airweave/domains/search/instant/service.py @@ -49,13 +49,16 @@ async def search( ctx: ApiContext, readable_id: str, request: InstantSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Build plan from request and execute.""" start_time = time.monotonic() ctx.logger.info(f"Instant search started collection={readable_id} query={request.query!r}") try: - result = await self._execute(db, ctx, readable_id, request, start_time) + result = await self._execute( + db, ctx, readable_id, request, start_time, user_principal_override + ) duration_ms = int((time.monotonic() - start_time) * 1000) ctx.logger.info( f"Instant search completed collection={readable_id} " @@ -87,6 +90,7 @@ async def _execute( readable_id: str, request: InstantSearchRequest, start_time: float, + user_principal_override: str | None = None, ) -> SearchResults: """Internal execution — resolve collection, build plan, execute.""" collection = await self._collection_repo.get_by_readable_id(db, readable_id, ctx) @@ -107,6 +111,7 @@ async def _execute( db=db, ctx=ctx, collection_readable_id=readable_id, + user_principal=user_principal_override, ) duration_ms = int((time.monotonic() - start_time) * 1000) diff --git a/backend/airweave/domains/search/protocols.py b/backend/airweave/domains/search/protocols.py index 1ae3251b3..04cee0ae1 100644 --- a/backend/airweave/domains/search/protocols.py +++ b/backend/airweave/domains/search/protocols.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable from sqlalchemy.ext.asyncio import AsyncSession @@ -39,6 +39,7 @@ async def execute( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: Optional[str] = None, ) -> SearchResults: """Execute a search plan and return results.""" ... @@ -68,6 +69,7 @@ async def search( ctx: ApiContext, readable_id: str, request: InstantSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute instant search and return results.""" ... @@ -83,6 +85,7 @@ async def search( ctx: ApiContext, readable_id: str, request: ClassicSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute classic search and return results.""" ... @@ -98,6 +101,7 @@ async def search( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute agentic search and return results.""" ... diff --git a/backend/airweave/domains/search/tests/test_executor.py b/backend/airweave/domains/search/tests/test_executor.py index 8da30518a..1768da077 100644 --- a/backend/airweave/domains/search/tests/test_executor.py +++ b/backend/airweave/domains/search/tests/test_executor.py @@ -17,6 +17,7 @@ import pytest +from airweave.domains.access_control.fakes.broker import FakeAccessBroker from airweave.domains.embedders.fakes.embedder import FakeDenseEmbedder, FakeSparseEmbedder from airweave.domains.search.adapters.vector_db.fakes.vector_db import FakeVectorDB from airweave.domains.search.executor import ( @@ -255,6 +256,7 @@ def _build_executor( sc_repo=sc_repo or FakeSourceConnectionRepository(), source_registry=source_registry or FakeSourceRegistry(), source_lifecycle=source_lifecycle or FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) @@ -913,6 +915,7 @@ async def test_dense_embedding_failure_propagates(self): sc_repo=FakeSourceConnectionRepository(), source_registry=FakeSourceRegistry(), source_lifecycle=FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) with pytest.raises(RuntimeError, match="provider down"): @@ -938,6 +941,7 @@ async def test_sparse_embedding_failure_propagates(self): sc_repo=FakeSourceConnectionRepository(), source_registry=FakeSourceRegistry(), source_lifecycle=FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) with pytest.raises(RuntimeError, match="timeout"): diff --git a/backend/airweave/domains/sources/token_providers/oauth.py b/backend/airweave/domains/sources/token_providers/oauth.py index 844a6326a..6ce63cfc4 100644 --- a/backend/airweave/domains/sources/token_providers/oauth.py +++ b/backend/airweave/domains/sources/token_providers/oauth.py @@ -146,6 +146,31 @@ async def force_refresh(self) -> str: self._apply_refresh(result) return self._token + async def get_token_for_resource(self, resource_scope: str) -> Optional[str]: + """Exchange refresh token for an access token scoped to a different resource. + + Used by SharePoint Online to get a SP REST API token from a Graph-scoped + refresh token. The exchange does not persist the rotated refresh token. + + Returns None if refresh is not supported. + """ + if not self._can_refresh: + return None + + try: + async with get_db_context() as db: + result = await self._oauth2_service.exchange_token_for_scope( + db=db, + integration_short_name=self._source_short_name, + connection_id=self._connection_id, + ctx=self._ctx, + scope=resource_scope, + ) + return str(result) + except Exception as e: + self._logger.warning(f"Failed to exchange token for scope {resource_scope}: {e}") + return None + # ------------------------------------------------------------------ # Private # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/sources/validation.py b/backend/airweave/domains/sources/validation.py index 195e9f567..25edf4ecb 100644 --- a/backend/airweave/domains/sources/validation.py +++ b/backend/airweave/domains/sources/validation.py @@ -37,14 +37,12 @@ def validate_config( """ entry = self._get_entry_or_404(short_name) - if not config_fields: - return {} - - payload = self._as_mapping(config_fields) - config_class = entry.config_ref if config_class is None: - return payload + # Source has no config schema — anything goes + return self._as_mapping(config_fields) if config_fields else {} + + payload = self._as_mapping(config_fields) if config_fields else {} self._enforce_feature_flags(short_name, payload, config_class, ctx) diff --git a/backend/airweave/domains/storage/file_service.py b/backend/airweave/domains/storage/file_service.py index 1d3897942..945ebf40d 100644 --- a/backend/airweave/domains/storage/file_service.py +++ b/backend/airweave/domains/storage/file_service.py @@ -61,7 +61,7 @@ def _ensure_base_dir(self) -> None: @staticmethod async def _resolve_headers(auth: SourceAuthProvider, url: str) -> dict: """Build auth headers. Pre-signed URLs skip the bearer token.""" - if "X-Amz-Algorithm" in url: + if "X-Amz-Algorithm" in url or "tempauth=" in url: return {} token = await auth.get_token() if hasattr(auth, "get_token") else None if not token: diff --git a/backend/airweave/domains/sync_pipeline/builders/destinations.py b/backend/airweave/domains/sync_pipeline/builders/destinations.py index f96bb9faa..2fa2896db 100644 --- a/backend/airweave/domains/sync_pipeline/builders/destinations.py +++ b/backend/airweave/domains/sync_pipeline/builders/destinations.py @@ -24,6 +24,7 @@ async def build_destinations( collection: schemas.CollectionRecord, logger: ContextualLogger, execution_config: Optional[SyncConfig] = None, + source_supports_acl: bool = False, ) -> List[BaseDestination]: """Build destinations.""" return await cls._create_destinations( @@ -31,6 +32,7 @@ async def build_destinations( collection=collection, logger=logger, execution_config=execution_config, + source_supports_acl=source_supports_acl, ) # ------------------------------------------------------------------------- @@ -44,6 +46,7 @@ async def _create_destinations( collection: schemas.CollectionRecord, logger: ContextualLogger, execution_config: Optional[SyncConfig] = None, + source_supports_acl: bool = False, ) -> List[BaseDestination]: """Create destination instances.""" destinations = [] @@ -59,6 +62,7 @@ async def _create_destinations( destination_connection_id=destination_connection_id, collection=collection, logger=logger, + source_supports_acl=source_supports_acl, ) if destination: destinations.append(destination) @@ -87,18 +91,20 @@ async def _create_single_destination( destination_connection_id: UUID, collection: schemas.CollectionRecord, logger: ContextualLogger, + source_supports_acl: bool = False, ) -> Optional[BaseDestination]: """Create a single destination instance.""" if destination_connection_id != NATIVE_VESPA_UUID: logger.warning(f"Unknown destination connection {destination_connection_id}, skipping") return None - return await cls._create_vespa(collection, logger) + return await cls._create_vespa(collection, logger, source_supports_acl=source_supports_acl) @classmethod async def _create_vespa( cls, collection: schemas.CollectionRecord, logger: ContextualLogger, + source_supports_acl: bool = False, ) -> BaseDestination: """Create native Vespa destination directly.""" logger.info("Using native Vespa destination (settings-based)") @@ -109,6 +115,7 @@ async def _create_vespa( organization_id=collection.organization_id, vector_size=None, logger=logger, + source_supports_acl=source_supports_acl, ) logger.info("Created native Vespa destination") return destination diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index ca0dff8c9..cdf225b02 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -184,12 +184,14 @@ async def create_orchestrator( execution_config=resolved_config, access_token=access_token, ) + source_entry = self._source_registry.get(sc.short_name) destinations = await self._build_destinations( db=db, sync=sync, collection=collection, ctx=ctx, execution_config=resolved_config, + source_supports_acl=source_entry.supports_access_control, ) entity_tracker = await self._build_entity_tracker( db=db, @@ -498,6 +500,7 @@ async def _build_destinations( collection: schemas.CollectionRecord, ctx: BaseContext, execution_config: SyncConfig, + source_supports_acl: bool = False, ) -> list: """Build destination instances for the sync.""" dest_logger = LoggerConfigurator.configure_logger( @@ -513,6 +516,7 @@ async def _build_destinations( collection=collection, logger=dest_logger, execution_config=execution_config, + source_supports_acl=source_supports_acl, ) # ------------------------------------------------------------------------- diff --git a/backend/airweave/platform/configs/auth.py b/backend/airweave/platform/configs/auth.py index 8659d872e..c6b62cd8b 100644 --- a/backend/airweave/platform/configs/auth.py +++ b/backend/airweave/platform/configs/auth.py @@ -752,6 +752,52 @@ class ShopifyAuthConfig(AuthConfig): ) +class SharePointOnlineAppAuthConfig(AuthConfig): + """SharePoint Online app-only authentication using client credentials. + + Uses client_id + client_secret for Microsoft Graph API calls, + and a certificate private key for SharePoint REST API calls. + Requires an Azure AD app registration with application permissions + and admin consent. + """ + + tenant_id: str = Field( + title="Tenant ID", + description="Azure AD tenant ID (e.g., 'contoso.onmicrosoft.com' or a UUID)", + min_length=1, + ) + client_id: str = Field( + title="Client ID", + description="Application (client) ID from the Azure AD app registration", + min_length=1, + ) + client_secret: str = Field( + title="Client Secret", + description="Client secret from the Azure AD app registration (for Graph API)", + min_length=1, + json_schema_extra={"is_secret": True}, + ) + private_key: str = Field( + title="Private Key (PEM)", + description=( + "PEM-encoded private key for certificate authentication " + "(for SharePoint REST API). Starts with '-----BEGIN PRIVATE KEY-----'" + ), + min_length=1, + json_schema_extra={"is_secret": True}, + ) + certificate: str = Field( + default="", + title="Certificate (PEM)", + description=( + "PEM-encoded certificate that was uploaded to the Azure AD app registration. " + "Used to compute the x5t thumbprint for SP REST API token exchange. " + "If omitted, SP site group expansion will not work." + ), + json_schema_extra={"is_secret": True}, + ) + + class ServiceNowAuthConfig(AuthConfig): """ServiceNow instance authentication credentials schema. diff --git a/backend/airweave/platform/configs/config.py b/backend/airweave/platform/configs/config.py index 70ca32ac5..c522e4d5e 100644 --- a/backend/airweave/platform/configs/config.py +++ b/backend/airweave/platform/configs/config.py @@ -1194,24 +1194,20 @@ class SharePointOnlineConfig(SourceConfig): default="", title="SharePoint Site URL", description=( - "URL of the SharePoint site(s) to sync. Supports a single URL " - "(e.g., 'https://contoso.sharepoint.com/sites/Marketing'), " - "comma-separated URLs for multiple sites, or leave empty to " - "sync all accessible sites." + "URL of a specific SharePoint site to sync " + "(e.g., 'https://contoso.sharepoint.com/sites/Marketing'). " + "Leave empty to sync all sites in the tenant." ), ) @field_validator("site_url") @classmethod def validate_site_url_ssrf(cls, v: str) -> str: - """Validate each comma-separated site URL for SSRF safety.""" + """Validate site URL for SSRF safety.""" if not v: return v - for url in v.split(","): - url = url.strip() - if url: - validate_url(url) - return v + validate_url(v.strip()) + return v.strip() include_personal_sites: bool = Field( default=False, diff --git a/backend/airweave/platform/destinations/vespa/destination.py b/backend/airweave/platform/destinations/vespa/destination.py index 7e006c665..b92cd14d7 100644 --- a/backend/airweave/platform/destinations/vespa/destination.py +++ b/backend/airweave/platform/destinations/vespa/destination.py @@ -112,10 +112,12 @@ async def create( instance.organization_id = organization_id # Initialize components + source_supports_acl = kwargs.get("source_supports_acl", False) instance._client = await VespaClient.connect(logger=instance.logger) instance._transformer = EntityTransformer( collection_id=collection_id, logger=instance.logger, + source_supports_acl=source_supports_acl, ) instance._query_builder = QueryBuilder() diff --git a/backend/airweave/platform/destinations/vespa/transformer.py b/backend/airweave/platform/destinations/vespa/transformer.py index 378ce214e..d118f6d62 100644 --- a/backend/airweave/platform/destinations/vespa/transformer.py +++ b/backend/airweave/platform/destinations/vespa/transformer.py @@ -25,7 +25,7 @@ def _sanitize_for_vespa(text: str) -> str: - """Sanitize text for Vespa by removing illegal characters. + r"""Sanitize text for Vespa by removing illegal characters. Vespa strictly rejects: 1. Control characters (code points < 32) except \n (0x0A), \r (0x0D), \t (0x09) @@ -167,15 +167,20 @@ def __init__( self, collection_id: Optional[UUID] = None, logger: Optional[ContextualLogger] = None, + source_supports_acl: bool = False, ): """Initialize the entity transformer. Args: collection_id: SQL collection UUID for multi-tenant filtering logger: Optional logger for debug/warning messages + source_supports_acl: Whether the source supports access control. + When True, entities with no access data default to invisible (fail-closed). + When False, entities default to public (visible to all). """ self.collection_id = collection_id self._logger = logger or default_logger + self._source_supports_acl = source_supports_acl def transform(self, entity: BaseEntity) -> VespaDocument: """Transform a single entity to Vespa document format. @@ -343,13 +348,17 @@ def _add_access_control_fields(self, fields: Dict[str, Any], entity: BaseEntity) """Add access control fields. Always sets access control fields with appropriate defaults: - - AC-enabled sources: Use the actual ACL values from entity.access - - Non-AC sources: Set is_public=True so entities are visible to everyone + - If entity has access data: use the actual ACL values + - AC-enabled source but no access data: fail-closed (invisible) + - Non-AC source: default to public (visible to everyone) """ access = getattr(entity, "access", None) if access is not None: fields["access_is_public"] = access.is_public fields["access_viewers"] = access.viewers if access.viewers else [] + elif self._source_supports_acl: + fields["access_is_public"] = False + fields["access_viewers"] = [] else: fields["access_is_public"] = True fields["access_viewers"] = [] diff --git a/backend/airweave/platform/sources/__init__.py b/backend/airweave/platform/sources/__init__.py index 84b97bd74..0dc48be82 100644 --- a/backend/airweave/platform/sources/__init__.py +++ b/backend/airweave/platform/sources/__init__.py @@ -50,7 +50,7 @@ from .servicenow import ServiceNowSource from .sharepoint import SharePointSource from .sharepoint2019v2.source import SharePoint2019V2Source -from .sharepoint_online.source import SharePointOnlineSource +from .sharepoint_online.source import SharePointOnlineAppSource, SharePointOnlineSource from .shopify import ShopifySource from .slab import SlabSource from .slack import SlackSource @@ -117,6 +117,7 @@ SharePointSource, SharePoint2019V2Source, SharePointOnlineSource, + SharePointOnlineAppSource, ShopifySource, SlabSource, SliteSource, diff --git a/backend/airweave/platform/sources/sharepoint_online/__init__.py b/backend/airweave/platform/sources/sharepoint_online/__init__.py index c6b62d75e..a8701157f 100644 --- a/backend/airweave/platform/sources/sharepoint_online/__init__.py +++ b/backend/airweave/platform/sources/sharepoint_online/__init__.py @@ -1,8 +1,12 @@ """SharePoint Online source connector. Uses Microsoft Graph API for content sync and Entra ID for access control. +Two variants: OAuth (delegated) and App (client credentials). """ -from airweave.platform.sources.sharepoint_online.source import SharePointOnlineSource +from airweave.platform.sources.sharepoint_online.source import ( + SharePointOnlineAppSource, + SharePointOnlineSource, +) -__all__ = ["SharePointOnlineSource"] +__all__ = ["SharePointOnlineSource", "SharePointOnlineAppSource"] diff --git a/backend/airweave/platform/sources/sharepoint_online/builders.py b/backend/airweave/platform/sources/sharepoint_online/builders.py index d4d1cb59f..eea0e83d1 100644 --- a/backend/airweave/platform/sources/sharepoint_online/builders.py +++ b/backend/airweave/platform/sources/sharepoint_online/builders.py @@ -9,11 +9,10 @@ from typing import Any, Dict, List, Optional from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.entities._base import Breadcrumb +from airweave.platform.entities._base import AccessControl, Breadcrumb from airweave.platform.entities.sharepoint_online import ( SharePointOnlineDriveEntity, SharePointOnlineFileEntity, - SharePointOnlineItemEntity, SharePointOnlinePageEntity, SharePointOnlineSiteEntity, ) @@ -32,6 +31,7 @@ def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: async def build_site_entity( site_data: Dict[str, Any], breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlineSiteEntity: """Build a site entity from Graph API site data.""" site_id = site_data.get("id") @@ -53,6 +53,7 @@ async def build_site_entity( created_at=_parse_datetime(site_data.get("createdDateTime")), last_modified_at=_parse_datetime(site_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) @@ -60,6 +61,7 @@ async def build_drive_entity( drive_data: Dict[str, Any], site_id: str, breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlineDriveEntity: """Build a drive entity from Graph API drive data.""" drive_id = drive_data.get("id") @@ -84,6 +86,7 @@ async def build_drive_entity( created_at=_parse_datetime(drive_data.get("createdDateTime")), last_modified_at=_parse_datetime(drive_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) @@ -154,46 +157,11 @@ async def build_file_entity( ) -async def build_item_entity( - item_data: Dict[str, Any], - site_id: str, - list_id: str, - breadcrumbs: List[Breadcrumb], -) -> SharePointOnlineItemEntity: - """Build a list item entity from Graph API list item data.""" - item_id = item_data.get("id") - if not item_id: - raise EntityProcessingError("Missing id for list item") - - fields = item_data.get("fields", {}) or {} - title = fields.get("Title") or fields.get("title") or item_data.get("id", "Untitled") - - content_type = None - ct_obj = item_data.get("contentType") - if ct_obj: - content_type = ct_obj.get("name") - - spo_entity_id = f"spo:item:{site_id}:{list_id}:{item_id}" - - return SharePointOnlineItemEntity( - spo_entity_id=spo_entity_id, - item_id=item_id, - list_id=list_id, - site_id=site_id, - title=title, - web_url=item_data.get("webUrl", ""), - content_type=content_type, - fields=fields, - created_at=_parse_datetime(item_data.get("createdDateTime")), - updated_at=_parse_datetime(item_data.get("lastModifiedDateTime")), - breadcrumbs=breadcrumbs, - ) - - async def build_page_entity( page_data: Dict[str, Any], site_id: str, breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlinePageEntity: """Build a page entity from Graph API site page data.""" page_id = page_data.get("id") @@ -214,4 +182,5 @@ async def build_page_entity( created_at=_parse_datetime(page_data.get("createdDateTime")), updated_at=_parse_datetime(page_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) diff --git a/backend/airweave/platform/sources/sharepoint_online/client.py b/backend/airweave/platform/sources/sharepoint_online/client.py index a93e20dde..f98d70d16 100644 --- a/backend/airweave/platform/sources/sharepoint_online/client.py +++ b/backend/airweave/platform/sources/sharepoint_online/client.py @@ -137,6 +137,12 @@ async def search_sites(self, query: str = "*") -> AsyncGenerator[Dict[str, Any], async for site in self.get_paginated(url, params): yield site + async def get_all_sites(self) -> AsyncGenerator[Dict[str, Any], None]: + """Enumerate all sites in the tenant (requires application permissions).""" + url = f"{GRAPH_BASE_URL}/sites/getAllSites" + async for site in self.get_paginated(url): + yield site + async def get_subsites(self, site_id: str) -> AsyncGenerator[Dict[str, Any], None]: """Get subsites of a SharePoint site.""" url = f"{GRAPH_BASE_URL}/sites/{site_id}/sites" @@ -233,11 +239,18 @@ async def get_drive_delta( self, drive_id: str, delta_token: str = "", + prefer_headers: Optional[List[str]] = None, ) -> Tuple[List[Dict[str, Any]], str]: """Get changes since the last delta token. Returns (changed_items, new_delta_token). If delta_token is empty, returns all items (initial sync). + + Args: + drive_id: The drive to query. + delta_token: Continuation token from a previous delta query. + prefer_headers: Optional Prefer header values for app-only delta + (e.g., ["deltashowsharingchanges", "deltashowremovedasdeleted"]). """ if delta_token: url = delta_token # Delta tokens are full URLs @@ -249,7 +262,15 @@ async def get_drive_delta( delta_link = "" while current_url: - data = await self.get(current_url) + if prefer_headers: + headers = await self._headers() + headers["Prefer"] = ", ".join(prefer_headers) + self.logger.debug(f"GET {current_url} (Prefer: {headers['Prefer']})") + response = await self._http_client.get(current_url, headers=headers, timeout=30.0) + response.raise_for_status() + data = response.json() + else: + data = await self.get(current_url) items = data.get("value", []) all_items.extend(items) @@ -282,6 +303,20 @@ async def get_item_permissions( return [] raise + async def get_drive_root_permissions( + self, + drive_id: str, + ) -> List[Dict[str, Any]]: + """Get permissions for the root of a drive (site-level permissions).""" + url = f"{GRAPH_BASE_URL}/drives/{drive_id}/root/permissions" + try: + data = await self.get(url) + return data.get("value", []) + except httpx.HTTPStatusError as e: + if e.response.status_code in (404, 403): + return [] + raise + # -- Lists -- async def get_lists(self, site_id: str) -> AsyncGenerator[Dict[str, Any], None]: diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index cfa141458..c5eb96336 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -18,6 +18,10 @@ Incremental sync: - Uses Graph delta queries (/drives/{id}/root/delta) - Per-drive delta tokens stored in cursor + +Two source variants: +- SharePointOnlineSource: OAuth (delegated user auth) +- SharePointOnlineAppSource: Client credentials (app-only auth) """ from __future__ import annotations @@ -34,11 +38,14 @@ from airweave.domains.access_control.schemas import MembershipTuple from airweave.domains.browse_tree.types import BrowseNode, NodeSelectionData from airweave.domains.sources.exceptions import SourceAuthError +from airweave.domains.sources.token_providers.credential import DirectCredentialProvider from airweave.domains.sources.token_providers.protocol import TokenProviderProtocol +from airweave.domains.sources.token_providers.static import StaticTokenProvider from airweave.domains.storage import FileSkippedException from airweave.domains.storage.file_service import FileService from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.domains.syncs.cursors.cursor import SyncCursor +from airweave.platform.configs.auth import SharePointOnlineAppAuthConfig from airweave.platform.configs.config import SharePointOnlineConfig from airweave.platform.cursors.sharepoint_online import SharePointOnlineCursor from airweave.platform.decorators import source @@ -53,6 +60,7 @@ retry_if_rate_limit_or_timeout, wait_rate_limit_with_backoff, ) +from airweave.platform.sources.sharepoint_online.acl import extract_access_control from airweave.platform.sources.sharepoint_online.builders import ( build_drive_entity, build_file_entity, @@ -76,52 +84,77 @@ class PendingFileDownload: item_id: str -@source( - name="SharePoint Online", - short_name="sharepoint_online", - auth_methods=[ - AuthenticationMethod.OAUTH_BROWSER, - AuthenticationMethod.OAUTH_TOKEN, - AuthenticationMethod.AUTH_PROVIDER, - ], - oauth_type=OAuthType.WITH_ROTATING_REFRESH, - auth_config_class=None, - config_class=SharePointOnlineConfig, - supports_continuous=True, - cursor_class=SharePointOnlineCursor, - supports_access_control=True, - supports_browse_tree=True, - feature_flag="sharepoint_2019_v2", - labels=["Collaboration", "File Storage"], -) -class SharePointOnlineSource(BaseSource): - """SharePoint Online source using Microsoft Graph API. +# ============================================================================= +# Base class — shared sync, browse tree, download, and ACL logic +# ============================================================================= + - Syncs sites, drives, files, lists, and pages with full ACL support. - Uses Entra ID for group membership expansion. +class SharePointOnlineBase(BaseSource): + """Shared implementation for SharePoint Online sources. + + Subclasses must implement the auth-specific hooks: + - create() — class constructor + - _get_access_token() — return a valid Microsoft Graph token + - _handle_401() — refresh/re-exchange on 401, return new token + - _make_sp_token_provider() — callable returning SP REST API token + - _get_download_auth(url) — auth suitable for file download + - _discover_sites(graph_client) — site discovery strategy """ - @classmethod - async def create( - cls, - *, - auth: TokenProviderProtocol, - logger: ContextualLogger, - http_client: AirweaveHttpClient, - config: SharePointOnlineConfig, - ) -> SharePointOnlineSource: - """Create and configure a SharePoint Online source instance.""" - instance = cls(auth=auth, logger=logger, http_client=http_client) - instance._site_url = config.site_url.rstrip("/") if config.site_url else "" - instance._include_personal_sites = config.include_personal_sites - instance._include_pages = config.include_pages - instance._item_level_entra_groups: set = set() - instance._item_level_sp_groups: set = set() - return instance + # Instance attributes set by _init_common() + _site_url: str + _include_personal_sites: bool + _include_pages: bool + _item_level_entra_groups: set + _item_level_sp_groups: set + + def _init_common(self, config: SharePointOnlineConfig) -> None: + """Initialize fields shared by both OAuth and client-credentials sources.""" + self._site_url = config.site_url.rstrip("/") if config.site_url else "" + self._include_personal_sites = config.include_personal_sites + self._include_pages = config.include_pages + self._item_level_entra_groups = set() + self._item_level_sp_groups = set() + + # -- Auth hooks (subclasses override) -- + + async def _get_access_token(self) -> str: + """Get a valid Microsoft Graph access token.""" + raise NotImplementedError + + async def _handle_401(self) -> str: + """Handle a 401 by refreshing/re-exchanging. Returns new token.""" + raise NotImplementedError + + def _make_sp_token_provider(self) -> Optional[Callable]: + """Create an async callable returning a SharePoint REST API token, or None.""" + raise NotImplementedError + + async def _get_download_auth(self, url: str) -> Any: + """Return an auth object suitable for FileService.download_from_url.""" + return self.auth + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover SharePoint sites to sync.""" + raise NotImplementedError + + @property + def _delta_prefer_headers(self) -> List[str]: + """Prefer headers for delta queries (permission change tracking).""" + return [] + + # -- Shared client factories -- def _create_graph_client(self) -> GraphClient: return GraphClient( - access_token_provider=self.auth.get_token, + access_token_provider=self._get_access_token, + http_client=self.http_client, + logger=self.logger, + ) + + def _create_group_expander(self) -> EntraGroupExpander: + return EntraGroupExpander( + access_token_provider=self._get_access_token, http_client=self.http_client, logger=self.logger, ) @@ -134,13 +167,13 @@ def _create_graph_client(self) -> GraphClient: ) async def _get(self, url: str, params: Optional[Dict] = None) -> Dict[str, Any]: """Make an authenticated GET request to Microsoft Graph API.""" - token = await self.auth.get_token() + token = await self._get_access_token() headers = {"Authorization": f"Bearer {token}", "Accept": "application/json"} response = await self.http_client.get(url, headers=headers, params=params) - if response.status_code == 401 and self.auth.supports_refresh: + if response.status_code == 401: self.logger.warning("Received 401 from Microsoft Graph API — refreshing token") - new_token = await self.auth.force_refresh() + new_token = await self._handle_401() headers = {"Authorization": f"Bearer {new_token}", "Accept": "application/json"} response = await self.http_client.get(url, headers=headers, params=params) @@ -151,42 +184,12 @@ async def _get(self, url: str, params: Optional[Dict] = None) -> Dict[str, Any]: ) return response.json() - def _create_group_expander(self) -> EntraGroupExpander: - return EntraGroupExpander( - access_token_provider=self.auth.get_token, - http_client=self.http_client, - logger=self.logger, - ) - - def _derive_sp_resource_scope(self) -> Optional[str]: - """Derive the SharePoint resource scope from the site URL. - - E.g. https://neenacorp.sharepoint.com/sites/JAman - -> https://neenacorp.sharepoint.com/.default - """ + def _derive_sp_hostname(self) -> Optional[str]: + """Derive the SharePoint hostname from the site URL.""" if not self._site_url: return None parsed = urlparse(self._site_url) - if not parsed.netloc: - return None - return f"https://{parsed.netloc}/.default" - - def _make_sp_token_provider(self) -> Optional[Callable]: - """Create an async callable that returns a SharePoint-scoped token. - - Returns None if the site URL is not set or no token manager is available. - """ - sp_scope = self._derive_sp_resource_scope() - if not sp_scope: - return None - - async def _provider() -> str: - token = await self.get_token_for_resource(sp_scope) - if not token: - raise RuntimeError(f"Could not obtain SharePoint token for scope {sp_scope}") - return token - - return _provider + return parsed.netloc or None def _track_entity_groups(self, entity: BaseEntity) -> None: """Track Entra ID and SP site groups found in entity permissions.""" @@ -236,14 +239,7 @@ async def get_browse_children( self, parent_node_id: Optional[str] = None, ) -> List[BrowseNode]: - """Lazy-load tree nodes from Microsoft Graph API. - - Tree structure: - - Root (parent_node_id=None): returns discovered sites - - Site node (site:{site_id}): returns drives for the site - - Drive node (drive:{site_id}|{drive_id}): returns root children of the drive - - Folder node (folder:{drive_id}|{folder_id}): returns children of the folder - """ + """Lazy-load tree nodes from Microsoft Graph API.""" graph_client = self._create_graph_client() nodes: List[BrowseNode] = [] @@ -381,10 +377,13 @@ async def _download_and_save_file( entity.url = ( f"https://graph.microsoft.com/v1.0/drives/{drive_id}/items/{item_id}/content" ) + + auth = await self._get_download_auth(entity.url) + await files.download_from_url( entity=entity, client=self.http_client, - auth=self.auth, + auth=auth, logger=self.logger, ) return entity @@ -416,6 +415,8 @@ async def download_one(item: PendingFileDownload): self.logger.debug(f"File download skipped for {item.drive_id}/{item.item_id}") except EntityProcessingError as e: self.logger.warning(f"Skipping file download: {e}") + except Exception as e: + self.logger.warning(f"Unexpected error downloading {item.entity.name}: {e}") tasks = [asyncio.create_task(download_one(p)) for p in pending] await asyncio.gather(*tasks, return_exceptions=True) @@ -469,39 +470,6 @@ async def generate_entities( async for entity in self._incremental_sync(cursor, files): yield entity - async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: - """Discover sites to sync based on config. - - Supports: - - Single URL: "https://tenant.sharepoint.com/sites/MySite" - - Comma-separated: "https://tenant.sharepoint.com/sites/A, .../sites/B" - - Empty string: discover all accessible sites - """ - sites = [] - - if self._site_url: - urls = [u.strip() for u in self._site_url.split(",") if u.strip()] - for url in urls: - parsed = urlparse(url) - hostname = parsed.netloc - site_path = parsed.path.lstrip("/") - try: - site = await graph_client.get_site_by_url(hostname, site_path) - sites.append(site) - except SourceAuthError: - raise - except Exception as e: - self.logger.warning(f"Could not resolve site URL {url}: {e}") - raise - else: - async for site in graph_client.search_sites("*"): - if not self._include_personal_sites and site.get("isPersonalSite", False): - continue - sites.append(site) - - self.logger.info(f"Discovered {len(sites)} sites to sync") - return sites - async def _resolve_unresolved_viewers( self, entity: BaseEntity, graph_client: GraphClient ) -> None: @@ -528,11 +496,7 @@ async def _resolve_unresolved_viewers( entity.access.viewers = new_viewers async def _fetch_sp_group_viewers(self) -> List[str]: - """Fetch all SP site groups and return their viewer strings. - - Uses the shared http_client with SP-scoped token headers. - Returns empty list if SP token is unavailable. - """ + """Fetch all SP site groups and return their viewer strings.""" sp_token_provider = self._make_sp_token_provider() if not sp_token_provider or not self._site_url: return [] @@ -578,8 +542,25 @@ async def _full_sync( # noqa: C901 for site_data in sites: site_id = site_data.get("id", "") + # Collect all drives for this site (single API call) + all_drives = [] + async for drive_data in graph_client.get_drives(site_id): + all_drives.append(drive_data) + + # Fetch site-level permissions from the first drive's root. + site_access = None + if all_drives: + try: + site_permissions = await graph_client.get_drive_root_permissions( + all_drives[0]["id"] + ) + site_access = await extract_access_control(site_permissions) + except Exception as e: + self.logger.warning(f"Could not fetch site-level permissions: {e}") + try: - site_entity = await build_site_entity(site_data, []) + site_entity = await build_site_entity(site_data, [], access=site_access) + self._track_entity_groups(site_entity) yield site_entity entity_count += 1 @@ -595,10 +576,24 @@ async def _full_sync( # noqa: C901 sp_group_viewers = await self._fetch_sp_group_viewers() - async for drive_data in graph_client.get_drives(site_id): + for drive_data in all_drives: drive_id = drive_data.get("id", "") try: - drive_entity = await build_drive_entity(drive_data, site_id, site_breadcrumbs) + # Each drive gets its own root permissions + drive_access = site_access + if drive_id != all_drives[0]["id"]: + try: + drive_permissions = await graph_client.get_drive_root_permissions( + drive_id + ) + drive_access = await extract_access_control(drive_permissions) + except Exception: + pass # Fall back to site_access + + drive_entity = await build_drive_entity( + drive_data, site_id, site_breadcrumbs, access=drive_access + ) + self._track_entity_groups(drive_entity) yield drive_entity entity_count += 1 @@ -661,6 +656,8 @@ async def _full_sync( # noqa: C901 except EntityProcessingError as e: self.logger.warning(f"Skipping file: {e}") + except Exception as e: + self.logger.warning(f"Unexpected error processing file: {e}") if pending_files and files: downloaded = await self._download_files_parallel(pending_files, files) @@ -670,7 +667,9 @@ async def _full_sync( # noqa: C901 if cursor: try: - _, delta_token = await graph_client.get_drive_delta(drive_id) + _, delta_token = await graph_client.get_drive_delta( + drive_id, prefer_headers=self._delta_prefer_headers + ) if delta_token: cursor_schema = SharePointOnlineCursor(**cursor.data) cursor_schema.update_entity_cursor( @@ -696,8 +695,9 @@ async def _full_sync( # noqa: C901 async for page_data in graph_client.get_pages(site_id): try: page_entity = await build_page_entity( - page_data, site_id, site_breadcrumbs + page_data, site_id, site_breadcrumbs, access=site_access ) + self._track_entity_groups(page_entity) yield page_entity entity_count += 1 except EntityProcessingError as e: @@ -743,7 +743,9 @@ async def _incremental_sync( # noqa: C901 for drive_id, token in delta_tokens.items(): try: - changed_items, new_token = await graph_client.get_drive_delta(drive_id, token) + changed_items, new_token = await graph_client.get_drive_delta( + drive_id, token, prefer_headers=self._delta_prefer_headers + ) except SourceAuthError: raise except Exception as e: @@ -848,7 +850,19 @@ async def _targeted_sync( # noqa: C901 try: site_data = await graph_client.get_site(site_id) - site_entity = await build_site_entity(site_data, []) + + # Fetch site-level permissions from first drive root + targeted_site_access = None + async for peek_drive in graph_client.get_drives(site_id): + try: + perms = await graph_client.get_drive_root_permissions(peek_drive["id"]) + targeted_site_access = await extract_access_control(perms) + except Exception: + pass + break + + site_entity = await build_site_entity(site_data, [], access=targeted_site_access) + self._track_entity_groups(site_entity) yield site_entity entity_count += 1 except SourceAuthError: @@ -935,7 +949,19 @@ async def _sync_drive( """Sync all files in a single drive (used by both full and targeted sync).""" try: drive_data = await graph_client.get_drive(drive_id) - drive_entity = await build_drive_entity(drive_data, site_id, site_breadcrumbs) + + # Fetch drive root permissions for the drive entity + drive_access = None + try: + drive_permissions = await graph_client.get_drive_root_permissions(drive_id) + drive_access = await extract_access_control(drive_permissions) + except Exception: + pass + + drive_entity = await build_drive_entity( + drive_data, site_id, site_breadcrumbs, access=drive_access + ) + self._track_entity_groups(drive_entity) yield drive_entity drive_breadcrumbs = site_breadcrumbs + [ @@ -1044,10 +1070,7 @@ async def _expand_entra_groups( yield membership async def _expand_sp_site_groups(self) -> AsyncGenerator[MembershipTuple, None]: - """Expand tracked SP site groups into user memberships. - - Uses the shared http_client with SP-scoped token headers. - """ + """Expand tracked SP site groups into user memberships.""" sp_group_names = list(self._item_level_sp_groups) if not sp_group_names or not self._site_url: return @@ -1116,3 +1139,348 @@ async def generate_access_control_memberships( group_expander.log_stats() self.logger.info(f"Access control extraction complete: {membership_count} memberships") + + +# ============================================================================= +# OAuth source — delegated user auth +# ============================================================================= + + +@source( + name="SharePoint Online", + short_name="sharepoint_online", + auth_methods=[ + AuthenticationMethod.OAUTH_BROWSER, + AuthenticationMethod.OAUTH_TOKEN, + AuthenticationMethod.AUTH_PROVIDER, + ], + oauth_type=OAuthType.WITH_ROTATING_REFRESH, + auth_config_class=None, + config_class=SharePointOnlineConfig, + supports_continuous=True, + cursor_class=SharePointOnlineCursor, + supports_access_control=True, + supports_browse_tree=True, + feature_flag="sharepoint_2019_v2", + labels=["Collaboration", "File Storage"], +) +class SharePointOnlineSource(SharePointOnlineBase): + """SharePoint Online source using delegated OAuth. + + Uses the signed-in user's permissions via OAuth browser flow. + Site discovery uses Graph search (delegated permissions). + """ + + @classmethod + async def create( + cls, + *, + auth: TokenProviderProtocol, + logger: ContextualLogger, + http_client: AirweaveHttpClient, + config: SharePointOnlineConfig, + ) -> SharePointOnlineSource: + """Create and configure an OAuth SharePoint Online source.""" + instance = cls(auth=auth, logger=logger, http_client=http_client) + instance._init_common(config) + return instance + + async def _get_access_token(self) -> str: + return await self.auth.get_token() + + async def _handle_401(self) -> str: + if self.auth.supports_refresh: + return await self.auth.force_refresh() + return await self.auth.get_token() + + def _make_sp_token_provider(self) -> Optional[Callable]: + """Create SP token provider via OAuth scope exchange.""" + sp_scope = self._derive_sp_resource_scope() + if not sp_scope: + return None + + async def _provider() -> str: + token = await self.get_token_for_resource(sp_scope) + if not token: + raise RuntimeError(f"Could not obtain SharePoint token for scope {sp_scope}") + return token + + return _provider + + def _derive_sp_resource_scope(self) -> Optional[str]: + """Derive the SharePoint resource scope from the site URL. + + E.g. https://neenacorp.sharepoint.com/sites/JAman + -> https://neenacorp.sharepoint.com/.default + """ + if not self._site_url: + return None + parsed = urlparse(self._site_url) + if not parsed.netloc: + return None + return f"https://{parsed.netloc}/.default" + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover sites via Graph search (delegated permissions). + + Supports: + - Single URL: "https://tenant.sharepoint.com/sites/MySite" + - Comma-separated: "https://tenant.sharepoint.com/sites/A, .../sites/B" + - Empty string: search all accessible sites + """ + sites = [] + + if self._site_url: + urls = [u.strip() for u in self._site_url.split(",") if u.strip()] + for url in urls: + parsed = urlparse(url) + hostname = parsed.netloc + site_path = parsed.path.lstrip("/") + try: + site = await graph_client.get_site_by_url(hostname, site_path) + sites.append(site) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning(f"Could not resolve site URL {url}: {e}") + raise + else: + async for site in graph_client.search_sites("*"): + if not self._include_personal_sites and site.get("isPersonalSite", False): + continue + sites.append(site) + + self.logger.info(f"Discovered {len(sites)} sites to sync") + return sites + + +# ============================================================================= +# Client credentials source — app-only auth +# ============================================================================= + + +@source( + name="SharePoint Online (App)", + short_name="sharepoint_online_app", + auth_methods=[AuthenticationMethod.DIRECT], + auth_config_class=SharePointOnlineAppAuthConfig, + config_class=SharePointOnlineConfig, + supports_continuous=True, + cursor_class=SharePointOnlineCursor, + supports_access_control=True, + supports_browse_tree=True, + feature_flag="sharepoint_2019_v2", + labels=["Collaboration", "File Storage"], +) +class SharePointOnlineAppSource(SharePointOnlineBase): + """SharePoint Online source using client credentials (app-only auth). + + Uses client_id + client_secret for Graph API and certificate-based + authentication for SharePoint REST API. Requires Azure AD app registration + with application permissions and admin consent. + """ + + _tenant_id: str + _client_id: str + _client_secret: str + _private_key: str + _certificate: str + _graph_token: Optional[str] + _graph_token_expires: float + _sp_tokens: Dict[str, tuple[str, float]] + + @classmethod + async def create( + cls, + *, + auth: DirectCredentialProvider, + logger: ContextualLogger, + http_client: AirweaveHttpClient, + config: SharePointOnlineConfig, + ) -> SharePointOnlineAppSource: + """Create and configure a client-credentials SharePoint Online source.""" + instance = cls(auth=auth, logger=logger, http_client=http_client) + instance._init_common(config) + + creds: SharePointOnlineAppAuthConfig = auth.credentials + instance._tenant_id = creds.tenant_id + instance._client_id = creds.client_id + instance._client_secret = creds.client_secret + instance._private_key = creds.private_key + instance._certificate = creds.certificate + + # Token cache + instance._graph_token = None + instance._graph_token_expires = 0.0 + instance._sp_tokens = {} # hostname -> (token, expires_at) + + # Exchange for initial Graph token + instance._graph_token = await instance._exchange_graph_token() + instance._graph_token_expires = asyncio.get_event_loop().time() + 3500 + + return instance + + # -- Token exchange (app-only mode) -- + + async def _exchange_graph_token(self) -> str: + """Exchange client credentials for a Microsoft Graph access token.""" + url = f"https://login.microsoftonline.com/{self._tenant_id}/oauth2/v2.0/token" + async with httpx.AsyncClient() as client: + resp = await client.post( + url, + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": "https://graph.microsoft.com/.default", + }, + ) + resp.raise_for_status() + data = resp.json() + self.logger.info(f"App-only Graph token obtained (expires_in={data.get('expires_in')})") + return str(data["access_token"]) + + async def _exchange_sp_token_with_certificate(self, hostname: str) -> str: + """Exchange certificate credentials for a SharePoint REST API access token.""" + import base64 + import hashlib + import time as _time + + import jwt as pyjwt + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + from cryptography.x509 import load_pem_x509_certificate + + token_url = f"https://login.microsoftonline.com/{self._tenant_id}/oauth2/v2.0/token" + + loaded_key = serialization.load_pem_private_key(self._private_key.encode(), password=None) + if not isinstance(loaded_key, RSAPrivateKey): + raise ValueError("SharePoint certificate auth requires an RSA private key") + private_key: RSAPrivateKey = loaded_key + + if not self._certificate: + raise ValueError( + "Certificate PEM is required for SP REST API token exchange. " + "Provide the PEM certificate that was uploaded to the Azure AD app registration." + ) + + cert = load_pem_x509_certificate(self._certificate.encode()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + cert_hash = hashlib.sha1(cert_der).digest() # noqa: S324 + x5t = base64.urlsafe_b64encode(cert_hash).rstrip(b"=").decode() + + now = int(_time.time()) + assertion = pyjwt.encode( + { + "aud": token_url, + "iss": self._client_id, + "sub": self._client_id, + "jti": str(now), + "nbf": now, + "exp": now + 600, + }, + private_key, + algorithm="RS256", + headers={"x5t": x5t}, + ) + + async with httpx.AsyncClient() as client: + resp = await client.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_assertion_type": ( + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + ), + "client_assertion": assertion, + "scope": f"https://{hostname}/.default", + }, + ) + resp.raise_for_status() + data = resp.json() + self.logger.info( + f"App-only SP token for {hostname} obtained (expires_in={data.get('expires_in')})" + ) + return str(data["access_token"]) + + async def _get_sp_token(self, hostname: str) -> str: + """Get a valid SP REST API token for a hostname, re-exchanging if expired.""" + now = asyncio.get_event_loop().time() + cached = self._sp_tokens.get(hostname) + if cached: + token, expires_at = cached + if now < expires_at: + return token + token = await self._exchange_sp_token_with_certificate(hostname) + self._sp_tokens[hostname] = (token, now + 3500) + return token + + # -- Auth hooks -- + + async def _get_access_token(self) -> str: + now = asyncio.get_event_loop().time() + if self._graph_token and now < self._graph_token_expires: + return self._graph_token + self._graph_token = await self._exchange_graph_token() + self._graph_token_expires = now + 3500 # ~58 min + return self._graph_token + + async def _handle_401(self) -> str: + self._graph_token_expires = 0 # force re-exchange + return await self._get_access_token() + + def _make_sp_token_provider(self) -> Optional[Callable]: + """Create SP token provider via certificate exchange.""" + hostname = self._derive_sp_hostname() + if not hostname: + return None + + async def _provider() -> str: + return await self._get_sp_token(hostname) + + return _provider + + @property + def _delta_prefer_headers(self) -> List[str]: + return [ + "deltashowsharingchanges", + "deltashowremovedasdeleted", + "deltatraversepermissiongaps", + ] + + async def _get_download_auth(self, url: str) -> Any: + """For client-credentials auth, use StaticTokenProvider for Graph URLs.""" + if "tempauth=" in url: + return self.auth # pre-signed URL, no auth needed + graph_token = await self._get_access_token() + return StaticTokenProvider(graph_token) + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover sites via getAllSites (application permissions). + + When site_url is set: resolve the specific site. + When empty: use getAllSites for complete enumeration. + """ + sites = [] + + if self._site_url: + parsed = urlparse(self._site_url) + hostname = parsed.netloc + site_path = parsed.path.lstrip("/") + try: + site = await graph_client.get_site_by_url(hostname, site_path) + sites.append(site) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning(f"Could not resolve site URL {self._site_url}: {e}") + raise + else: + async for site in graph_client.get_all_sites(): + if not self._include_personal_sites and site.get("isPersonalSite", False): + continue + sites.append(site) + + self.logger.info(f"Discovered {len(sites)} sites to sync") + return sites diff --git a/backend/tests/unit/platform/configs/test_config_ssrf.py b/backend/tests/unit/platform/configs/test_config_ssrf.py index 115e4daf7..832487b7a 100644 --- a/backend/tests/unit/platform/configs/test_config_ssrf.py +++ b/backend/tests/unit/platform/configs/test_config_ssrf.py @@ -217,21 +217,23 @@ def test_accepts_public_site(self): class TestSharePointOnlineConfig: - def test_rejects_loopback_in_csv(self): + def test_rejects_loopback(self): with pytest.raises(ValidationError, match="SSRF|blocked"): - SharePointOnlineConfig( - site_url="https://ok.sharepoint.com, http://127.0.0.1" - ) + SharePointOnlineConfig(site_url="http://127.0.0.1") def test_empty_site_url_passes(self): cfg = SharePointOnlineConfig(site_url="") assert cfg.site_url == "" - def test_accepts_valid_csv(self): + def test_missing_site_url_defaults_empty(self): + cfg = SharePointOnlineConfig() + assert cfg.site_url == "" + + def test_accepts_valid_site_url(self): cfg = SharePointOnlineConfig( - site_url="https://a.sharepoint.com, https://b.sharepoint.com" + site_url="https://contoso.sharepoint.com/sites/Marketing" ) - assert cfg.site_url == "https://a.sharepoint.com, https://b.sharepoint.com" + assert cfg.site_url == "https://contoso.sharepoint.com/sites/Marketing" class TestSalesforceConfig: From 90fe115327ca8c3f62e91dc439188916b436c288 Mon Sep 17 00:00:00 2001 From: EwanTauran Date: Tue, 14 Apr 2026 09:17:47 +0200 Subject: [PATCH 02/25] fix(docs): update connector links to use the correct documentation paths Updated the links for various connectors in the overview documentation to point to the correct `/docs/connectors/` paths instead of the previous `/connectors/` paths. This change ensures that users can access the appropriate documentation for each connector. --- fern/docs/pages/connectors/overview.mdx | 90 ++++++++++++------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/fern/docs/pages/connectors/overview.mdx b/fern/docs/pages/connectors/overview.mdx index f937a05bb..cbfde1337 100644 --- a/fern/docs/pages/connectors/overview.mdx +++ b/fern/docs/pages/connectors/overview.mdx @@ -18,14 +18,14 @@ Airweave supports many different types of connectors across productivity tools, ### Popular Connectors - - - - - - - - + + + + + + + + ### All Connectors @@ -33,57 +33,57 @@ Airweave supports many different types of connectors across productivity tools, ### Productivity & Collaboration -- [Notion](/connectors/notion) -- [Slack](/connectors/slack) -- [Asana](/connectors/asana) -- [Monday](/connectors/monday) -- [Linear](/connectors/linear) -- [Trello](/connectors/trello) -- [Clickup](/connectors/clickup) -- [Todoist](/connectors/todoist) -- [Airtable](/connectors/airtable) +- [Notion](/docs/connectors/notion) +- [Slack](/docs/connectors/slack) +- [Asana](/docs/connectors/asana) +- [Monday](/docs/connectors/monday) +- [Linear](/docs/connectors/linear) +- [Trello](/docs/connectors/trello) +- [Clickup](/docs/connectors/clickup) +- [Todoist](/docs/connectors/todoist) +- [Airtable](/docs/connectors/airtable) ### Cloud Storage & Documents -- [Google Drive](/connectors/google_drive) -- [Google Docs](/connectors/google_docs) -- [Google Slides](/connectors/google_slides) -- [Dropbox](/connectors/dropbox) -- [OneDrive](/connectors/onedrive) -- [Box](/connectors/box) -- [SharePoint](/connectors/sharepoint) -- [Word](/connectors/word) -- [OneNote](/connectors/onenote) +- [Google Drive](/docs/connectors/google_drive) +- [Google Docs](/docs/connectors/google_docs) +- [Google Slides](/docs/connectors/google_slides) +- [Dropbox](/docs/connectors/dropbox) +- [OneDrive](/docs/connectors/onedrive) +- [Box](/docs/connectors/box) +- [SharePoint](/docs/connectors/sharepoint) +- [Word](/docs/connectors/word) +- [OneNote](/docs/connectors/onenote) ### Developer Tools -- [GitHub](/connectors/github) -- [GitLab](/connectors/gitlab) -- [Bitbucket](/connectors/bitbucket) -- [Jira](/connectors/jira) -- [Confluence](/connectors/confluence) +- [GitHub](/docs/connectors/github) +- [GitLab](/docs/connectors/gitlab) +- [Bitbucket](/docs/connectors/bitbucket) +- [Jira](/docs/connectors/jira) +- [Confluence](/docs/connectors/confluence) ### CRM & Sales -- [Salesforce](/connectors/salesforce) -- [HubSpot](/connectors/hubspot) -- [Pipedrive](/connectors/pipedrive) -- [Attio](/connectors/attio) -- [Zoho CRM](/connectors/zoho_crm) +- [Salesforce](/docs/connectors/salesforce) +- [HubSpot](/docs/connectors/hubspot) +- [Pipedrive](/docs/connectors/pipedrive) +- [Attio](/docs/connectors/attio) +- [Zoho CRM](/docs/connectors/zoho_crm) ### Communication & Email -- [Gmail](/connectors/gmail) -- [Outlook Mail](/connectors/outlook_mail) -- [Outlook Calendar](/connectors/outlook_calendar) -- [Google Calendar](/connectors/google_calendar) -- [Teams](/connectors/teams) +- [Gmail](/docs/connectors/gmail) +- [Outlook Mail](/docs/connectors/outlook_mail) +- [Outlook Calendar](/docs/connectors/outlook_calendar) +- [Google Calendar](/docs/connectors/google_calendar) +- [Teams](/docs/connectors/teams) ### Support & Service -- [Zendesk](/connectors/zendesk) +- [Zendesk](/docs/connectors/zendesk) ### E-commerce & Payments -- [Shopify](/connectors/shopify) -- [Stripe](/connectors/stripe) +- [Shopify](/docs/connectors/shopify) +- [Stripe](/docs/connectors/stripe) ### Other -- [Ctti](/connectors/ctti) +- [Ctti](/docs/connectors/ctti) From c21ba64db98ca0520a312bc039fe2cdb224d2795 Mon Sep 17 00:00:00 2001 From: EwanTauran Date: Tue, 14 Apr 2026 13:53:19 +0200 Subject: [PATCH 03/25] chore: update Node.js version in GitHub Actions workflow from 18 to 20 --- .github/workflows/fern-docs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/fern-docs.yml b/.github/workflows/fern-docs.yml index 6a652c8b0..054c9bbe1 100644 --- a/.github/workflows/fern-docs.yml +++ b/.github/workflows/fern-docs.yml @@ -71,7 +71,7 @@ jobs: - name: Setup Node.js uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: - node-version: "18" + node-version: "20" - name: Install Fern run: npm install -g fern-api From 4b2e744652c4dacc393013a5eb9f617c352a42e9 Mon Sep 17 00:00:00 2001 From: Rauf Akdemir Date: Wed, 15 Apr 2026 11:31:14 +0200 Subject: [PATCH 04/25] =?UTF-8?q?refactor:=20unified=20SyncService=20?= =?UTF-8?q?=E2=80=94=20consolidate=20lifecycle=20+=20record=20+=20runner?= =?UTF-8?q?=20(#1719)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: consolidate SyncLifecycleService + SyncRecordService into unified SyncService * fix: restore Temporal workflow start for manual sync runs SourceConnectionService.run was missing the Temporal workflow invocation after creating a PENDING sync job, causing API-triggered syncs to remain stuck in PENDING. Also fixes ruff, mypy, and test failures from the unified SyncService refactor. * test: add unit tests for run, get_jobs, cancel_job, count_by_organization Covers the new sync lifecycle proxy methods on SourceConnectionService to improve diff-coverage for the PR. * test: cover force_full_sync, _resolve_collection, _resolve_connection Adds tests for the remaining uncovered paths in SourceConnectionService including error cases for missing collection/connection. * test: add comprehensive unit tests for SyncService lifecycle/CRUD methods Covers get, pause, resume, delete, resolve_destination_ids, trigger_run (validation paths), get_jobs, cancel_job (success/not-found/wrong-status/ temporal-failure/workflow-not-found), validate_force_full_sync, create (federated/no-schedule/with-cron), _resolve_cron, _validate_cron_for_source, _cancel_active_syncs, _wait_for_terminal, and _schedule_cleanup. * fix: use enum value for DB status update + add SyncService unit tests SyncStateMachine.transition was assigning the SyncStatus enum member (name='PAUSED') to the ORM column, but PostgreSQL expects the lowercase value ('paused'). Use .value for correct serialization. Also adds comprehensive unit tests covering get, pause, resume, delete, trigger_run, get_jobs, cancel_job, validate_force_full_sync, create, _resolve_cron, _validate_cron_for_source, and all private helpers. * fix: pass entity count fields through SourceConnectionJob mapping The refactored get_jobs() and cancel_job() methods were constructing SourceConnectionJob objects without entity count fields (entities_inserted, entities_updated, entities_deleted), causing them to always default to 0. This broke E2E tests that assert entities_inserted > 0 after a completed sync. * Complete SourceConnectionJob field mapping with duration_seconds, entities_failed, error_category * fix: add missing mock fields to _make_sync_job_schema The MagicMock-based sync job fixture was missing error_category and entities_skipped, causing Pydantic validation to fail when the service mapped the mock to SourceConnectionJob. * chore: remove audit script and output files from PR * fix: address PR #1719 review comments across sync and source_connection domains - SyncService.create: raise ValueError instead of returning None for federated sources and no-schedule cases - SyncService.delete: accept single sync_id instead of List[UUID] - SyncService.trigger_run: consolidate Temporal workflow start internally, accept collection/connection/force_full_sync params - SyncService.get: raise HTTPException(404) instead of ValueError - SyncService.get_jobs: push limit to DB query instead of in-memory slice - SyncService.cancel_job: add retry logic for Temporal cancellation - SourceConnectionService: remove temporal_workflow_service dependency, remove try/except around event bus publish, use NotFoundException in _resolve helpers - SourceConnectionUpdateService: delegate sync creation to sync_service.create instead of direct repo access - SourceConnectionDeletionService: simplify collection field access - CollectionService.delete: use asyncio.gather for per-sync deletion - OAuthCallbackService: fix resume reason string - FakeSyncService: add explicit SyncServiceProtocol inheritance - Update all protocols, fakes, and test files to match new signatures * Update backend/airweave/domains/collections/service.py Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> * Update backend/airweave/domains/source_connections/service.py Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> * fix: guard against scheduling syncs for federated search sources in update service Add upfront federated_search check in SourceConnectionUpdateService before calling sync_service.create, returning HTTP 400 instead of letting the ValueError propagate from SyncService. * fix: remove stray paren and lint violation breaking CI The asyncio.gather → sequential loop conversion left an unmatched ')' in CollectionService.delete, causing a SyntaxError that cascaded across import-linter, test, and test-public-api checks. Also breaks the _duration_seconds signature to satisfy ruff E501 (line too long). * fix: resolve mypy Column[UUID] and Liskov substitution violations - collections/service: use schema result.id (UUID) instead of ORM db_obj.id (Column[UUID]) for sync_service.delete and repo.remove - delete/update: validate ORM collection to CollectionRecord schema before passing .id to sync_service methods - jobs/repository: widen ctx parameter from ApiContext to BaseContext in get_all_by_sync_id to satisfy the protocol contract * fix: guard sync creation for null-schedule case and fix test mocks - create.py: skip sync_service.create when schedule.cron is explicitly null and sync_immediately is false (prevents ValueError → 500) - delete.py: inline model_validate call to satisfy ruff format - test_delete/test_update: add required CollectionRecord fields to mock collection objects so model_validate succeeds * fix: ruff format and mypy str|None for sync_service.create name arg Inline resolve_destination_ids calls (ruff format) and use `obj_in.name or entry.name` to guarantee str for the name parameter. * fix: remove redundant sync_id ownership check from cancel_job The job lookup in SyncService.cancel_job is already org-scoped via ctx, so the extra sync_id check in SourceConnectionService was unnecessary. Simplifies cancel_job to delegate directly to the sync service. --------- Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com> --- backend/airweave/core/container/container.py | 7 - backend/airweave/core/container/factory.py | 48 +- backend/airweave/crud/crud_sync_job.py | 7 +- .../airweave/domains/collections/service.py | 31 +- .../domains/collections/tests/test_service.py | 36 +- .../domains/connect/tests/conftest.py | 4 +- .../domains/oauth/callback_service.py | 30 +- .../oauth/tests/test_callback_service.py | 30 +- .../domains/source_connections/create.py | 88 +- .../domains/source_connections/delete.py | 129 +-- .../source_connections/fakes/service.py | 14 +- .../domains/source_connections/service.py | 158 +++- .../source_connections/tests/test_create.py | 6 +- .../source_connections/tests/test_delete.py | 208 +---- .../tests/test_fake_service.py | 6 +- .../source_connections/tests/test_service.py | 306 +++++- .../source_connections/tests/test_update.py | 114 ++- .../domains/source_connections/update.py | 73 +- .../domains/syncs/fakes/lifecycle_service.py | 129 --- .../domains/syncs/fakes/record_service.py | 91 -- .../airweave/domains/syncs/fakes/service.py | 177 +++- .../domains/syncs/jobs/fakes/repository.py | 16 +- .../airweave/domains/syncs/jobs/protocols.py | 6 +- .../airweave/domains/syncs/jobs/repository.py | 9 +- .../domains/syncs/lifecycle_service.py | 466 ---------- backend/airweave/domains/syncs/protocols.py | 165 ++-- .../airweave/domains/syncs/record_service.py | 139 --- backend/airweave/domains/syncs/service.py | 518 ++++++++++- .../syncs/tests/test_lifecycle_service.py | 876 ------------------ .../syncs/tests/test_record_service.py | 254 ----- .../domains/syncs/tests/test_service.py | 608 +++++++++++- backend/conftest.py | 26 +- 32 files changed, 2114 insertions(+), 2661 deletions(-) delete mode 100644 backend/airweave/domains/syncs/fakes/lifecycle_service.py delete mode 100644 backend/airweave/domains/syncs/fakes/record_service.py delete mode 100644 backend/airweave/domains/syncs/lifecycle_service.py delete mode 100644 backend/airweave/domains/syncs/record_service.py delete mode 100644 backend/airweave/domains/syncs/tests/test_lifecycle_service.py delete mode 100644 backend/airweave/domains/syncs/tests/test_record_service.py diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index b2fb6d25a..4b0822ad0 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -99,11 +99,8 @@ ) from airweave.domains.syncs.protocols import ( SyncCursorRepositoryProtocol, - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, SyncRepositoryProtocol, SyncServiceProtocol, - SyncStateMachineProtocol, ) from airweave.domains.temporal.protocols import ( TemporalScheduleServiceProtocol, @@ -203,15 +200,11 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): # Sync domain sync_repo: SyncRepositoryProtocol sync_cursor_repo: SyncCursorRepositoryProtocol - # Sync cursor service — cursor CRUD operations sync_cursor_service: SyncCursorService sync_job_repo: SyncJobRepositoryProtocol - sync_record_service: SyncRecordServiceProtocol sync_job_service: SyncJobServiceProtocol sync_job_state_machine: SyncJobStateMachineProtocol - sync_state_machine: SyncStateMachineProtocol sync_service: SyncServiceProtocol - sync_lifecycle: SyncLifecycleServiceProtocol sync_factory: SyncFactoryProtocol entity_repo: EntityRepositoryProtocol diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 4d3d55ac4..9fe45e4cd 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -129,8 +129,6 @@ from airweave.domains.syncs.jobs.repository import SyncJobRepository from airweave.domains.syncs.jobs.service import SyncJobService from airweave.domains.syncs.jobs.state_machine import SyncJobStateMachine -from airweave.domains.syncs.lifecycle_service import SyncLifecycleService -from airweave.domains.syncs.record_service import SyncRecordService from airweave.domains.syncs.repository import SyncRepository from airweave.domains.syncs.service import SyncService from airweave.domains.syncs.state_machine import SyncStateMachine @@ -401,41 +399,24 @@ def create_container(settings: Settings) -> Container: ) sync_service = SyncService( - state_machine=sync_deps["sync_job_state_machine"], - sync_factory=sync_factory, - sync_state_machine=sync_deps["sync_state_machine"], - ) - - sync_record_service = SyncRecordService( sync_repo=source_deps["sync_repo"], sync_job_repo=source_deps["sync_job_repo"], - connection_repo=source_deps["conn_repo"], - ) - - sync_lifecycle = SyncLifecycleService( - sc_repo=source_deps["sc_repo"], - collection_repo=source_deps["collection_repo"], - connection_repo=source_deps["conn_repo"], sync_cursor_repo=sync_deps["sync_cursor_repo"], - sync_service=sync_record_service, - state_machine=sync_deps["sync_job_state_machine"], - sync_job_repo=source_deps["sync_job_repo"], + state_machine=sync_deps["sync_state_machine"], + job_state_machine=sync_deps["sync_job_state_machine"], temporal_workflow_service=sync_deps["temporal_workflow_service"], temporal_schedule_service=sync_deps["temporal_schedule_service"], - response_builder=sync_deps["response_builder"], - event_bus=event_bus, + sync_factory=sync_factory, ) # ----------------------------------------------------------------- - # Source connection sub-services + # Source connection sub-services (need sync_service) # ----------------------------------------------------------------- deletion_service = SourceConnectionDeletionService( sc_repo=source_deps["sc_repo"], collection_repo=source_deps["collection_repo"], - sync_job_repo=source_deps["sync_job_repo"], - sync_lifecycle=sync_lifecycle, response_builder=sync_deps["response_builder"], - temporal_workflow_service=sync_deps["temporal_workflow_service"], + sync_service=sync_service, ) update_service = SourceConnectionUpdateService( sc_repo=source_deps["sc_repo"], @@ -443,13 +424,13 @@ def create_container(settings: Settings) -> Container: connection_repo=source_deps["conn_repo"], cred_repo=source_deps["cred_repo"], sync_repo=source_deps["sync_repo"], - sync_record_service=sync_record_service, + sync_service=sync_service, source_service=source_deps["source_service"], + source_registry=source_deps["source_registry"], source_validation=source_validation, credential_encryptor=encryptor, response_builder=sync_deps["response_builder"], temporal_schedule_service=sync_deps["temporal_schedule_service"], - sync_state_machine=sync_deps["sync_state_machine"], ) create_service = SourceConnectionCreationService( sc_repo=source_deps["sc_repo"], @@ -459,8 +440,7 @@ def create_container(settings: Settings) -> Container: source_registry=source_deps["source_registry"], source_validation=source_validation, source_lifecycle=source_deps["source_lifecycle_service"], - sync_lifecycle=sync_lifecycle, - sync_record_service=sync_record_service, + sync_service=sync_service, response_builder=sync_deps["response_builder"], oauth_flow_service=oauth_flow_svc, temporal_workflow_service=sync_deps["temporal_workflow_service"], @@ -476,7 +456,8 @@ def create_container(settings: Settings) -> Container: source_registry=source_deps["source_registry"], auth_provider_registry=source_deps["auth_provider_registry"], response_builder=sync_deps["response_builder"], - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, + event_bus=event_bus, create_service=create_service, update_service=update_service, deletion_service=deletion_service, @@ -505,7 +486,7 @@ def create_container(settings: Settings) -> Container: collection_service = CollectionService( collection_repo=source_deps["collection_repo"], sc_repo=source_deps["sc_repo"], - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, settings=settings, deployment_metadata_repo=VectorDbDeploymentMetadataRepository(), @@ -521,10 +502,8 @@ def create_container(settings: Settings) -> Container: response_builder=sync_deps["response_builder"], source_registry=source_deps["source_registry"], source_lifecycle=source_deps["source_lifecycle_service"], - sync_lifecycle=sync_lifecycle, - sync_record_service=sync_record_service, + sync_service=sync_service, temporal_workflow_service=sync_deps["temporal_workflow_service"], - sync_state_machine=sync_deps["sync_state_machine"], event_bus=event_bus, organization_repo=OrgRepo(), sc_repo=source_deps["sc_repo"], @@ -619,9 +598,6 @@ def create_container(settings: Settings) -> Container: sync_job_service=sync_deps["sync_job_service"], sync_job_state_machine=sync_deps["sync_job_state_machine"], sync_service=sync_service, - sync_record_service=sync_record_service, - sync_lifecycle=sync_lifecycle, - sync_state_machine=sync_deps["sync_state_machine"], sync_factory=sync_factory, entity_repo=sync_deps["entity_repo"], access_broker=access_broker, diff --git a/backend/airweave/crud/crud_sync_job.py b/backend/airweave/crud/crud_sync_job.py index 7056a868d..fa906b682 100644 --- a/backend/airweave/crud/crud_sync_job.py +++ b/backend/airweave/crud/crud_sync_job.py @@ -39,8 +39,9 @@ async def get_all_by_sync_id( db: AsyncSession, sync_id: UUID, status: Optional[list[str]] = None, + limit: Optional[int] = None, ) -> list[SyncJob]: - """Get all jobs for a specific sync, optionally filtered by status.""" + """Get jobs for a sync; optional status filter; newest first; optional row limit.""" stmt = ( select(SyncJob, Sync.name.label("sync_name")) .join(Sync, SyncJob.sync_id == Sync.id) @@ -52,6 +53,10 @@ async def get_all_by_sync_id( # Database enum already uses uppercase values stmt = stmt.where(SyncJob.status.in_(status)) + stmt = stmt.order_by(SyncJob.created_at.desc()) + if limit is not None: + stmt = stmt.limit(limit) + result = await db.execute(stmt) jobs = [] for job, sync_name in result: diff --git a/backend/airweave/domains/collections/service.py b/backend/airweave/domains/collections/service.py index c77a6b3b6..86f77712a 100644 --- a/backend/airweave/domains/collections/service.py +++ b/backend/airweave/domains/collections/service.py @@ -21,7 +21,7 @@ ) from airweave.domains.embedders.protocols import DenseEmbedderRegistryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.collection import Collection from airweave.schemas.collection import SourceConnectionSummary @@ -33,7 +33,7 @@ def __init__( self, collection_repo: CollectionRepositoryProtocol, sc_repo: SourceConnectionRepositoryProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, + sync_service: SyncServiceProtocol, event_bus: EventBus, settings: Settings, deployment_metadata_repo: VectorDbDeploymentMetadataRepositoryProtocol, @@ -42,7 +42,7 @@ def __init__( """Initialize with injected dependencies.""" self._collection_repo = collection_repo self._sc_repo = sc_repo - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service self._event_bus = event_bus self._settings = settings self._deployment_metadata_repo = deployment_metadata_repo @@ -179,34 +179,31 @@ async def delete( if db_obj is None: raise CollectionNotFoundError(readable_id) - collection_id = db_obj.id - organization_id = ctx.organization.id - # Snapshot while session is fresh (teardown expires all objects via db.expire_all) result = self._to_response(db_obj) # Collect sync IDs before CASCADE removes them sync_ids = await self._sc_repo.get_sync_ids_for_collection( - db, organization_id=organization_id, readable_collection_id=result.readable_id + db, organization_id=ctx.organization.id, readable_collection_id=result.readable_id ) - # Cancel running workflows and wait for workers to stop - await self._sync_lifecycle.teardown_syncs_for_collection( - db, - sync_ids=sync_ids, - collection_id=collection_id, - organization_id=organization_id, - ctx=ctx, - ) + for sid in sync_ids: + await self._sync_service.delete( + db, + sync_id=sid, + collection_id=result.id, + organization_id=ctx.organization.id, + ctx=ctx, + ) # CASCADE-delete the collection and all child objects - await self._collection_repo.remove(db, id=collection_id, ctx=ctx) + await self._collection_repo.remove(db, id=result.id, ctx=ctx) # Publish event try: await self._event_bus.publish( CollectionLifecycleEvent.deleted( - organization_id=organization_id, + organization_id=ctx.organization.id, collection_id=result.id, collection_name=result.name, collection_readable_id=result.readable_id, diff --git a/backend/airweave/domains/collections/tests/test_service.py b/backend/airweave/domains/collections/tests/test_service.py index 1b16d38aa..7d019faf4 100644 --- a/backend/airweave/domains/collections/tests/test_service.py +++ b/backend/airweave/domains/collections/tests/test_service.py @@ -28,7 +28,7 @@ from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.models.collection import Collection from airweave.schemas.organization import Organization @@ -132,7 +132,7 @@ def _fake_dense_registry() -> FakeDenseEmbedderRegistry: def _build_service( collection_repo=None, sc_repo=None, - sync_lifecycle=None, + sync_service=None, event_bus=None, settings=None, deployment_metadata_repo=None, @@ -141,7 +141,7 @@ def _build_service( return CollectionService( collection_repo=collection_repo or FakeCollectionRepository(), sc_repo=sc_repo or FakeSourceConnectionRepository(), - sync_lifecycle=sync_lifecycle or FakeSyncLifecycleService(), + sync_service=sync_service or FakeSyncService(), event_bus=event_bus or _FakeEventBus(), settings=settings or _fake_settings(), deployment_metadata_repo=deployment_metadata_repo @@ -388,7 +388,7 @@ async def test_delete_full_flow(): """delete() gathers sync IDs, calls teardown, cascade-deletes, publishes event.""" repo = FakeCollectionRepository() sc_repo = FakeSourceConnectionRepository() - sync_lifecycle = FakeSyncLifecycleService() + sync_service = FakeSyncService() event_bus = _FakeEventBus() col = _collection() @@ -402,7 +402,7 @@ async def test_delete_full_flow(): svc = _build_service( collection_repo=repo, sc_repo=sc_repo, - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, ) @@ -410,13 +410,13 @@ async def test_delete_full_flow(): assert result is not None - # Verify teardown was called with correct args - teardown_calls = [c for c in sync_lifecycle._calls if c[0] == "teardown_syncs_for_collection"] - assert len(teardown_calls) == 1 - _, _, sync_ids, coll_id, org_id, _ = teardown_calls[0] - assert set(sync_ids) == {sync_id_1, sync_id_2} - assert coll_id == COLLECTION_ID - assert org_id == ORG_ID + # Verify one delete per sync ID: ("delete", sync_id, collection_id, organization_id) + delete_calls = [c for c in sync_service._calls if c[0] == "delete"] + assert len(delete_calls) == 2 + assert {c[1] for c in delete_calls} == {sync_id_1, sync_id_2} + for c in delete_calls: + assert c[2] == COLLECTION_ID # collection_id + assert c[3] == ORG_ID # organization_id # Verify cascade delete was called remove_calls = [c for c in repo._calls if c[0] == "remove"] @@ -438,10 +438,10 @@ async def test_delete_not_found(): @pytest.mark.asyncio async def test_delete_no_syncs(): - """delete() works when collection has no syncs — teardown called with empty list.""" + """delete() works when collection has no syncs — no sync delete calls (empty gather).""" repo = FakeCollectionRepository() sc_repo = FakeSourceConnectionRepository() - sync_lifecycle = FakeSyncLifecycleService() + sync_service = FakeSyncService() event_bus = _FakeEventBus() col = _collection() @@ -452,7 +452,7 @@ async def test_delete_no_syncs(): svc = _build_service( collection_repo=repo, sc_repo=sc_repo, - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, ) @@ -460,7 +460,5 @@ async def test_delete_no_syncs(): assert result is not None - teardown_calls = [c for c in sync_lifecycle._calls if c[0] == "teardown_syncs_for_collection"] - assert len(teardown_calls) == 1 - _, _, sync_ids, _, _, _ = teardown_calls[0] - assert sync_ids == [] + delete_calls = [c for c in sync_service._calls if c[0] == "delete"] + assert len(delete_calls) == 0 diff --git a/backend/airweave/domains/connect/tests/conftest.py b/backend/airweave/domains/connect/tests/conftest.py index 2237507db..342f7c414 100644 --- a/backend/airweave/domains/connect/tests/conftest.py +++ b/backend/airweave/domains/connect/tests/conftest.py @@ -63,8 +63,8 @@ def org_repo(): @pytest.fixture -def sc_service(fake_sync_lifecycle): - return FakeSourceConnectionService(sync_lifecycle=fake_sync_lifecycle) +def sc_service(fake_sync_service): + return FakeSourceConnectionService(sync_service=fake_sync_service) @pytest.fixture diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py index a27f7eb3b..a682eda78 100644 --- a/backend/airweave/domains/oauth/callback_service.py +++ b/backend/airweave/domains/oauth/callback_service.py @@ -22,7 +22,7 @@ from airweave.core.logging import logger from airweave.core.protocols.encryption import CredentialEncryptor from airweave.core.protocols.event_bus import EventBus -from airweave.core.shared_models import AuthMethod, ConnectionStatus, SyncJobStatus, SyncStatus +from airweave.core.shared_models import AuthMethod, ConnectionStatus, SyncJobStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -45,12 +45,7 @@ ) from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, - SyncRepositoryProtocol, - SyncStateMachineProtocol, -) +from airweave.domains.syncs.protocols import SyncRepositoryProtocol, SyncServiceProtocol from airweave.domains.syncs.types import InvalidSyncTransitionError, OptimisticLockError from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol from airweave.models.collection import Collection @@ -87,10 +82,8 @@ def __init__( response_builder: ResponseBuilderProtocol, source_registry: SourceRegistryProtocol, source_lifecycle: SourceLifecycleServiceProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: SyncServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, - sync_state_machine: SyncStateMachineProtocol, event_bus: EventBus, organization_repo: OrganizationRepositoryProtocol, sc_repo: SourceConnectionRepositoryProtocol, @@ -107,10 +100,8 @@ def __init__( self._response_builder = response_builder self._source_registry = source_registry self._source_lifecycle = source_lifecycle - self._sync_lifecycle = sync_lifecycle - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._temporal_workflow_service = temporal_workflow_service - self._sync_state_machine = sync_state_machine self._event_bus = event_bus self._organization_repo = organization_repo self._sc_repo = sc_repo @@ -530,11 +521,9 @@ async def _complete_connection_common( # noqa: C901 if raw_cron: schedule_config = ScheduleConfig(cron=raw_cron) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) - sync_result = await self._sync_lifecycle.provision_sync( + sync_result = await self._sync_service.create( uow.session, name=payload.get("name") or source_entry.name, source_connection_id=connection.id, @@ -576,10 +565,9 @@ async def _complete_connection_common( # noqa: C901 if source_conn.sync_id: try: - await self._sync_state_machine.transition( - sync_id=source_conn.sync_id, - target=SyncStatus.ACTIVE, - ctx=ctx, + await self._sync_service.resume( + source_conn.sync_id, + ctx, reason="OAuth completed", ) except (InvalidSyncTransitionError, OptimisticLockError, ValueError): diff --git a/backend/airweave/domains/oauth/tests/test_callback_service.py b/backend/airweave/domains/oauth/tests/test_callback_service.py index 82bd3ad9a..a2b1a762c 100644 --- a/backend/airweave/domains/oauth/tests/test_callback_service.py +++ b/backend/airweave/domains/oauth/tests/test_callback_service.py @@ -120,10 +120,8 @@ def _service( response_builder=None, source_registry=None, source_lifecycle=None, - sync_lifecycle=None, - sync_record_service=None, + sync_service=None, temporal_workflow_service=None, - sync_state_machine=None, event_bus=None, ) -> OAuthCallbackService: return OAuthCallbackService( @@ -132,10 +130,8 @@ def _service( response_builder=response_builder or AsyncMock(), source_registry=source_registry or MagicMock(), source_lifecycle=source_lifecycle or AsyncMock(), - sync_lifecycle=sync_lifecycle or AsyncMock(), - sync_record_service=sync_record_service or AsyncMock(), + sync_service=sync_service or AsyncMock(), temporal_workflow_service=temporal_workflow_service or AsyncMock(), - sync_state_machine=sync_state_machine or AsyncMock(), event_bus=event_bus or AsyncMock(), organization_repo=organization_repo or FakeOrganizationRepository(), sc_repo=sc_repo or FakeSourceConnectionRepository(), @@ -869,8 +865,8 @@ async def test_federated_source_skips_sync_provisioning(self): svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=True)) ) - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) - svc._sync_lifecycle.provision_sync = AsyncMock() + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.create = AsyncMock() from airweave.domains.oauth import callback_service as callback_module @@ -916,7 +912,7 @@ async def commit(self): finally: monkeypatch.undo() - svc._sync_lifecycle.provision_sync.assert_not_awaited() + svc._sync_service.create.assert_not_awaited() async def test_claim_token_session_skips_mark_completed(self): svc = _service() @@ -931,9 +927,9 @@ async def test_claim_token_session_skips_mark_completed(self): return_value=SimpleNamespace(id=sc_id, connection_id=conn_id, sync_id=None) ) svc._init_session_repo.mark_completed = AsyncMock() - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) sync_id = uuid4() - svc._sync_lifecycle.provision_sync = AsyncMock( + svc._sync_service.create = AsyncMock( return_value=SimpleNamespace(sync_id=sync_id) ) @@ -1000,9 +996,9 @@ async def test_non_federated_source_provisions_sync_with_cron_schedule(self): svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=False)) ) - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) sync_id = uuid4() - svc._sync_lifecycle.provision_sync = AsyncMock( + svc._sync_service.create = AsyncMock( return_value=SimpleNamespace(sync_id=sync_id) ) @@ -1050,7 +1046,7 @@ async def commit(self): finally: monkeypatch.undo() - svc._sync_lifecycle.provision_sync.assert_awaited_once() + svc._sync_service.create.assert_awaited_once() svc._init_session_repo.mark_completed.assert_awaited_once() @@ -2219,7 +2215,7 @@ async def test_complete_connection_common_activates_sync(self): """_complete_connection_common transitions sync to ACTIVE after commit.""" from unittest.mock import patch - sync_state_machine = AsyncMock() + sync_service = AsyncMock() sc_repo = FakeSourceConnectionRepository() shell = _source_conn_shell() shell.connection_id = uuid4() @@ -2227,7 +2223,7 @@ async def test_complete_connection_common_activates_sync(self): svc = _service( sc_repo=sc_repo, - sync_state_machine=sync_state_machine, + sync_service=sync_service, ) svc._validate_config = MagicMock(return_value={}) collection = MagicMock() @@ -2267,4 +2263,4 @@ async def test_complete_connection_common_activates_sync(self): has_claim_token=True, ) - sync_state_machine.transition.assert_awaited_once() + sync_service.resume.assert_awaited_once() diff --git a/backend/airweave/domains/source_connections/create.py b/backend/airweave/domains/source_connections/create.py index e1c6ade60..63bc07aca 100644 --- a/backend/airweave/domains/source_connections/create.py +++ b/backend/airweave/domains/source_connections/create.py @@ -35,10 +35,7 @@ SourceValidationServiceProtocol, ) from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, -) +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol from airweave.models.connection_init_session import ConnectionInitStatus from airweave.schemas.connection import ConnectionCreate @@ -75,8 +72,7 @@ def __init__( source_registry: SourceRegistryProtocol, source_validation: SourceValidationServiceProtocol, source_lifecycle: SourceLifecycleServiceProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: SyncServiceProtocol, response_builder: ResponseBuilderProtocol, oauth_flow_service: OAuthFlowServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, @@ -92,8 +88,7 @@ def __init__( self._source_registry = source_registry self._source_validation = source_validation self._source_lifecycle = source_lifecycle - self._sync_lifecycle = sync_lifecycle - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._response_builder = response_builder self._oauth_flow_service = oauth_flow_service self._temporal_workflow_service = temporal_workflow_service @@ -450,23 +445,27 @@ async def _create_with_auth_provider( ) await uow.session.flush() connection_schema = schemas.Connection.model_validate(connection, from_attributes=True) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) - sync_result = await self._sync_lifecycle.provision_sync( - uow.session, - name=obj_in.name, - source_connection_id=connection.id, - destination_connection_ids=destination_ids, - collection_id=collection.id, - collection_readable_id=collection.readable_id, - source_entry=entry, - schedule_config=obj_in.schedule, - run_immediately=bool(obj_in.sync_immediately), - ctx=ctx, - uow=uow, - ) - await uow.session.flush() + + has_schedule = obj_in.schedule is None or ( + obj_in.schedule and obj_in.schedule.cron is not None + ) + sync_result = None + if bool(obj_in.sync_immediately) or has_schedule: + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + sync_result = await self._sync_service.create( + uow.session, + name=obj_in.name or entry.name, + source_connection_id=connection.id, + destination_connection_ids=destination_ids, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=entry, + schedule_config=obj_in.schedule, + run_immediately=bool(obj_in.sync_immediately), + ctx=ctx, + uow=uow, + ) + await uow.session.flush() source_conn = await self._sc_repo.create( uow.session, @@ -658,23 +657,28 @@ async def _create_authenticated_connection( ) await uow.session.flush() connection_schema = schemas.Connection.model_validate(connection, from_attributes=True) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) - sync_result = await self._sync_lifecycle.provision_sync( - uow.session, - name=obj_in.name, - source_connection_id=connection.id, - destination_connection_ids=destination_ids, - collection_id=collection.id, - collection_readable_id=collection.readable_id, - source_entry=entry, - schedule_config=obj_in.schedule, - run_immediately=bool(obj_in.sync_immediately), - ctx=ctx, - uow=uow, - ) - await uow.session.flush() + + has_schedule = obj_in.schedule is None or ( + obj_in.schedule and obj_in.schedule.cron is not None + ) + sync_result = None + if bool(obj_in.sync_immediately) or has_schedule: + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + sync_result = await self._sync_service.create( + uow.session, + name=obj_in.name or entry.name, + source_connection_id=connection.id, + destination_connection_ids=destination_ids, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=entry, + schedule_config=obj_in.schedule, + run_immediately=bool(obj_in.sync_immediately), + ctx=ctx, + uow=uow, + ) + await uow.session.flush() + source_conn = await self._sc_repo.create( uow.session, obj_in={ diff --git a/backend/airweave/domains/source_connections/delete.py b/backend/airweave/domains/source_connections/delete.py index 2c073b7ae..8a25a1416 100644 --- a/backend/airweave/domains/source_connections/delete.py +++ b/backend/airweave/domains/source_connections/delete.py @@ -1,22 +1,19 @@ """Source connection deletion service.""" -import asyncio from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession +from airweave import schemas from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException -from airweave.core.shared_models import SyncJobStatus from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.source_connections.protocols import ( ResponseBuilderProtocol, SourceConnectionDeletionServiceProtocol, SourceConnectionRepositoryProtocol, ) -from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol -from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.schemas.source_connection import SourceConnection as SourceConnectionSchema @@ -24,27 +21,21 @@ class SourceConnectionDeletionService(SourceConnectionDeletionServiceProtocol): """Deletes a source connection and all related data. The flow is: - 1. Cancel any running sync workflows and wait for them to stop. + 1. Delegate cancel + wait + cleanup scheduling to SyncService.delete. 2. CASCADE-delete the DB records (source connection, sync, jobs, entities). - 3. Fire-and-forget a Temporal cleanup workflow for the slow external - data deletion (Vespa, ARF, schedules) which can take minutes. """ - def __init__( + def __init__( # noqa: D107 self, sc_repo: SourceConnectionRepositoryProtocol, collection_repo: CollectionRepositoryProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, response_builder: ResponseBuilderProtocol, - temporal_workflow_service: TemporalWorkflowServiceProtocol, + sync_service: SyncServiceProtocol, ) -> None: self._sc_repo = sc_repo self._collection_repo = collection_repo - self._sync_job_repo = sync_job_repo - self._sync_lifecycle = sync_lifecycle self._response_builder = response_builder - self._temporal_workflow_service = temporal_workflow_service + self._sync_service = sync_service async def delete( self, @@ -58,112 +49,26 @@ async def delete( if not source_conn: raise NotFoundException("Source connection not found") - # Capture attributes upfront to avoid lazy-loading issues after session changes sync_id = source_conn.sync_id - collection = await self._collection_repo.get_by_readable_id( + collection_orm = await self._collection_repo.get_by_readable_id( db, readable_id=source_conn.readable_collection_id, ctx=ctx ) - if not collection: + if not collection_orm: raise NotFoundException("Collection not found") - collection_id = str(collection.id) - organization_id = str(collection.organization_id) + collection = schemas.CollectionRecord.model_validate(collection_orm, from_attributes=True) - # Build response before deletion response = await self._response_builder.build_response(db, source_conn, ctx) - # Cancel any running jobs and wait for the Temporal workflow to - # terminate before we cascade-delete the DB rows. if sync_id: - latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if latest_job and latest_job.status in [ - SyncJobStatus.PENDING, - SyncJobStatus.RUNNING, - SyncJobStatus.CANCELLING, - ]: - if latest_job.status in [SyncJobStatus.PENDING, SyncJobStatus.RUNNING]: - ctx.logger.info( - f"Cancelling job {latest_job.id} for source connection {id} before deletion" - ) - try: - await self._sync_lifecycle.cancel_job( - db, - source_connection_id=id, - job_id=latest_job.id, - ctx=ctx, - ) - except Exception as e: - ctx.logger.warning( - f"Failed to cancel job {latest_job.id} during deletion: {e}" - ) + await self._sync_service.delete( + db, + sync_id=sync_id, + collection_id=collection.id, + organization_id=collection.organization_id, + ctx=ctx, + cancel_timeout_seconds=15, + ) - # BARRIER: Wait for the workflow to reach a terminal state so - # the worker stops writing before we cascade-delete the rows. - reached_terminal = await self._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=15 - ) - if not reached_terminal: - ctx.logger.warning( - f"Job for sync {sync_id} did not reach terminal state within 15s " - f"-- proceeding with deletion anyway" - ) - - # Delete the source connection first (CASCADE removes sync, jobs, entities). await self._sc_repo.remove(db, id=id, ctx=ctx) - # Fire-and-forget: schedule async cleanup of external data (Vespa, ARF, - # Temporal schedules). This can take minutes for Vespa and must not - # block the API response. - if sync_id: - try: - await self._temporal_workflow_service.start_cleanup_sync_data_workflow( - sync_ids=[str(sync_id)], - collection_id=collection_id, - organization_id=organization_id, - ctx=ctx, - ) - except Exception as e: - ctx.logger.error( - f"Failed to schedule async cleanup for sync {sync_id}: {e}. " - f"Data may be orphaned in Vespa/ARF." - ) - return response - - async def _wait_for_sync_job_terminal_state( - self, - db: AsyncSession, - sync_id: UUID, - *, - timeout_seconds: int = 30, - poll_interval: float = 1.0, - ) -> bool: - """Wait for the latest sync job to reach a terminal state. - - Polls the database until the job reaches COMPLETED, FAILED, or CANCELLED. - Used as a cancellation barrier to prevent cleanup from running while - a Temporal worker is still actively writing. - - Args: - db: Database session. - sync_id: Sync ID whose latest job to monitor. - timeout_seconds: Maximum time to wait before giving up. - poll_interval: Seconds between poll attempts. - - Returns: - True if a terminal state was reached, False on timeout. - """ - terminal_states = { - SyncJobStatus.COMPLETED, - SyncJobStatus.FAILED, - SyncJobStatus.CANCELLED, - } - elapsed = 0.0 - while elapsed < timeout_seconds: - await asyncio.sleep(poll_interval) - elapsed += poll_interval - # Expire cached ORM objects to force a fresh read from the database - db.expire_all() - job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if job and job.status in terminal_states: - return True - return False diff --git a/backend/airweave/domains/source_connections/fakes/service.py b/backend/airweave/domains/source_connections/fakes/service.py index 547fd932b..8e0280e3e 100644 --- a/backend/airweave/domains/source_connections/fakes/service.py +++ b/backend/airweave/domains/source_connections/fakes/service.py @@ -12,7 +12,7 @@ SourceConnectionDeletionServiceProtocol, SourceConnectionUpdateServiceProtocol, ) -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( SourceConnection as SourceConnectionSchema, @@ -30,7 +30,7 @@ class FakeSourceConnectionService: def __init__( self, - sync_lifecycle: SyncLifecycleServiceProtocol, + sync_service: SyncServiceProtocol, create_service: Optional[SourceConnectionCreateServiceProtocol] = None, update_service: Optional[SourceConnectionUpdateServiceProtocol] = None, deletion_service: Optional[SourceConnectionDeletionServiceProtocol] = None, @@ -39,7 +39,7 @@ def __init__( self._list_items: List[SourceConnectionListItem] = [] self._redirect_urls: dict[str, str] = {} self._calls: list[tuple[Any, ...]] = [] - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service self._create_service = create_service self._update_service = update_service self._deletion_service = deletion_service @@ -112,7 +112,7 @@ async def run( force_full_sync: bool = False, ) -> SourceConnectionJob: self._calls.append(("run", db, id, ctx, force_full_sync)) - return await self._sync_lifecycle.run(db, id=id, ctx=ctx, force_full_sync=force_full_sync) + raise NotImplementedError("FakeSourceConnectionService.run not wired") async def get_jobs( self, @@ -123,7 +123,7 @@ async def get_jobs( limit: int = 100, ) -> List[SourceConnectionJob]: self._calls.append(("get_jobs", db, id, ctx, limit)) - return await self._sync_lifecycle.get_jobs(db, id=id, ctx=ctx, limit=limit) + return [] async def cancel_job( self, @@ -134,9 +134,7 @@ async def cancel_job( ctx: ApiContext, ) -> SourceConnectionJob: self._calls.append(("cancel_job", db, source_connection_id, job_id, ctx)) - return await self._sync_lifecycle.cancel_job( - db, source_connection_id=source_connection_id, job_id=job_id, ctx=ctx - ) + raise NotImplementedError("FakeSourceConnectionService.cancel_job not wired") async def get_sync_id(self, db: AsyncSession, *, id: UUID, ctx: ApiContext) -> dict: self._calls.append(("get_sync_id", db, id, ctx)) diff --git a/backend/airweave/domains/source_connections/service.py b/backend/airweave/domains/source_connections/service.py index 674bb1564..33cc05ef6 100644 --- a/backend/airweave/domains/source_connections/service.py +++ b/backend/airweave/domains/source_connections/service.py @@ -1,13 +1,17 @@ """Service for source connections.""" +from datetime import datetime from typing import List, Optional from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession +from airweave import schemas from airweave.api.context import ApiContext from airweave.core.datetime_utils import utc_now +from airweave.core.events.sync import SyncLifecycleEvent from airweave.core.exceptions import NotFoundException +from airweave.core.protocols.event_bus import EventBus from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -21,7 +25,7 @@ SourceConnectionUpdateServiceProtocol, ) from airweave.domains.sources.protocols import SourceRegistryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( SourceConnection as SourceConnectionSchema, @@ -34,6 +38,14 @@ ) +def _duration_seconds( + started_at: Optional[datetime], completed_at: Optional[datetime] +) -> Optional[float]: + if started_at and completed_at: + return (completed_at - started_at).total_seconds() + return None + + class SourceConnectionService(SourceConnectionServiceProtocol): """Service for source connections.""" @@ -49,7 +61,8 @@ def __init__( # noqa: D107 auth_provider_registry: AuthProviderRegistryProtocol, # Helpers response_builder: ResponseBuilderProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, + sync_service: SyncServiceProtocol, + event_bus: EventBus, # Sub-services create_service: SourceConnectionCreateServiceProtocol, update_service: SourceConnectionUpdateServiceProtocol, @@ -62,7 +75,8 @@ def __init__( # noqa: D107 self.source_registry = source_registry self.auth_provider_registry = auth_provider_registry self.response_builder = response_builder - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service + self._event_bus = event_bus self._create_service = create_service self._update_service = update_service self._deletion_service = deletion_service @@ -142,7 +156,8 @@ async def delete(self, db: AsyncSession, id: UUID, ctx: ApiContext) -> SourceCon return await self._deletion_service.delete(db, id=id, ctx=ctx) # ------------------------------------------------------------------ - # Sync lifecycle proxies + # Sync lifecycle proxies — resolve source_connection → sync_id, then + # delegate to the unified SyncService and map results. # ------------------------------------------------------------------ async def run( @@ -154,7 +169,52 @@ async def run( force_full_sync: bool = False, ) -> SourceConnectionJob: """Trigger a sync run for this source connection.""" - return await self._sync_lifecycle.run(db, id=id, ctx=ctx, force_full_sync=force_full_sync) + source_conn = await self._resolve_source_connection(db, id, ctx) + sync_id = source_conn.sync_id + assert sync_id is not None + + if force_full_sync: + await self._sync_service.validate_force_full_sync(db, sync_id, ctx) + + collection = await self._resolve_collection(db, source_conn, ctx) + connection = await self._resolve_connection(db, source_conn, ctx) + + sync, sync_job = await self._sync_service.trigger_run( + db, + sync_id=sync_id, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + await self._event_bus.publish( + SyncLifecycleEvent.pending( + organization_id=ctx.organization.id, + source_connection_id=id, + sync_job_id=sync_job.id, + sync_id=sync_id, + collection_id=collection.id, + source_type=connection.short_name, + collection_name=collection.name, + collection_readable_id=collection.readable_id, + ) + ) + + return SourceConnectionJob( + id=sync_job.id, + source_connection_id=id, + status=sync_job.status, + started_at=sync_job.started_at, + completed_at=sync_job.completed_at, + duration_seconds=_duration_seconds(sync_job.started_at, sync_job.completed_at), + entities_inserted=sync_job.entities_inserted or 0, + entities_updated=sync_job.entities_updated or 0, + entities_deleted=sync_job.entities_deleted or 0, + entities_failed=sync_job.entities_skipped or 0, + error=sync_job.error, + error_category=sync_job.error_category, + ) async def get_jobs( self, @@ -165,7 +225,29 @@ async def get_jobs( limit: int = 100, ) -> List[SourceConnectionJob]: """List sync jobs for this source connection.""" - return await self._sync_lifecycle.get_jobs(db, id=id, ctx=ctx, limit=limit) + source_conn = await self._resolve_source_connection(db, id, ctx) + sync_id = source_conn.sync_id + assert sync_id is not None + + jobs = await self._sync_service.get_jobs(db, sync_id=sync_id, ctx=ctx, limit=limit) + + return [ + SourceConnectionJob( + id=j.id, + source_connection_id=id, + status=j.status, + started_at=j.started_at, + completed_at=j.completed_at, + duration_seconds=_duration_seconds(j.started_at, j.completed_at), + entities_inserted=j.entities_inserted or 0, + entities_updated=j.entities_updated or 0, + entities_deleted=j.entities_deleted or 0, + entities_failed=j.entities_skipped or 0, + error=j.error, + error_category=j.error_category, + ) + for j in jobs + ] async def cancel_job( self, @@ -176,8 +258,21 @@ async def cancel_job( ctx: ApiContext, ) -> SourceConnectionJob: """Cancel a running sync job.""" - return await self._sync_lifecycle.cancel_job( - db, source_connection_id=source_connection_id, job_id=job_id, ctx=ctx + sync_job = await self._sync_service.cancel_job(db, job_id=job_id, ctx=ctx) + + return SourceConnectionJob( + id=sync_job.id, + source_connection_id=source_connection_id, + status=sync_job.status, + started_at=sync_job.started_at, + completed_at=sync_job.completed_at, + duration_seconds=_duration_seconds(sync_job.started_at, sync_job.completed_at), + entities_inserted=sync_job.entities_inserted or 0, + entities_updated=sync_job.entities_updated or 0, + entities_deleted=sync_job.entities_deleted or 0, + entities_failed=sync_job.entities_skipped or 0, + error=sync_job.error, + error_category=sync_job.error_category, ) async def get_sync_id(self, db: AsyncSession, *, id: UUID, ctx: ApiContext) -> dict: @@ -194,15 +289,50 @@ async def count_by_organization(self, db: AsyncSession, organization_id: UUID) - return await self.sc_repo.count_by_organization(db, organization_id) async def get_redirect_url(self, db: AsyncSession, *, code: str) -> str: - """Resolve a short redirect code to its final OAuth authorization URL. - - The redirect session is atomically consumed (deleted) on lookup, - enforcing one-time use per CASA Requirement #23. - """ + """Resolve a short redirect code to its final OAuth authorization URL.""" redirect_info = await self._redirect_session_repo.consume(db, code=code) if not redirect_info: raise NotFoundException("Authorization link expired or invalid") - # Check expiry *after* consume so expired tokens can't be replayed. if redirect_info.expires_at <= utc_now(): raise NotFoundException("Authorization link expired or invalid") return redirect_info.final_url + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + async def _resolve_source_connection( + self, db: AsyncSession, id: UUID, ctx: ApiContext + ) -> SourceConnection: + """Get a source connection and validate it has an associated sync.""" + source_conn = await self.sc_repo.get(db, id=id, ctx=ctx) + if not source_conn: + raise NotFoundException("Source connection not found") + if not source_conn.sync_id: + raise NotFoundException("No sync found for this source connection") + return source_conn + + async def _resolve_collection( + self, db: AsyncSession, source_conn: SourceConnection, ctx: ApiContext + ) -> schemas.CollectionRecord: + """Resolve the CollectionRecord schema for a source connection.""" + readable_id = source_conn.readable_collection_id + if not readable_id: + raise NotFoundException( + f"Source connection {source_conn.id} has no readable_collection_id" + ) + collection = await self.collection_repo.get_by_readable_id(db, str(readable_id), ctx) + if not collection: + raise NotFoundException("Collection not found") + return schemas.CollectionRecord.model_validate(collection, from_attributes=True) + + async def _resolve_connection( + self, db: AsyncSession, source_conn: SourceConnection, ctx: ApiContext + ) -> schemas.Connection: + """Resolve the Connection schema (not SourceConnection) for a source connection.""" + if not source_conn.connection_id: + raise NotFoundException(f"Source connection {source_conn.id} has no connection_id") + conn = await self.connection_repo.get(db, source_conn.connection_id, ctx) + if not conn: + raise NotFoundException(f"Connection {source_conn.connection_id} not found") + return schemas.Connection.model_validate(conn, from_attributes=True) diff --git a/backend/airweave/domains/source_connections/tests/test_create.py b/backend/airweave/domains/source_connections/tests/test_create.py index 61eef2dec..b57f596f2 100644 --- a/backend/airweave/domains/source_connections/tests/test_create.py +++ b/backend/airweave/domains/source_connections/tests/test_create.py @@ -26,8 +26,7 @@ from airweave.domains.sources.fakes.registry import FakeSourceRegistry from airweave.domains.sources.fakes.validation import FakeSourceValidationService from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService from airweave.schemas.organization import Organization from airweave.schemas.source_connection import ( @@ -86,8 +85,7 @@ def _service(entry) -> SourceConnectionCreationService: source_registry=registry, source_validation=FakeSourceValidationService(), source_lifecycle=FakeSourceLifecycleService(), - sync_lifecycle=FakeSyncLifecycleService(), - sync_record_service=FakeSyncRecordService(), + sync_service=FakeSyncService(), response_builder=FakeResponseBuilder(), oauth_flow_service=FakeOAuthFlowService(), temporal_workflow_service=FakeTemporalWorkflowService(), diff --git a/backend/airweave/domains/source_connections/tests/test_delete.py b/backend/airweave/domains/source_connections/tests/test_delete.py index 22b8d07fc..a70bf4428 100644 --- a/backend/airweave/domains/source_connections/tests/test_delete.py +++ b/backend/airweave/domains/source_connections/tests/test_delete.py @@ -1,37 +1,31 @@ """Unit tests for SourceConnectionDeletionService. Table-driven tests covering: -- Happy paths: no sync, completed job, running/cancelling/pending jobs -- Error paths: not found, collection not found, cancel failure, cleanup failure, timeout +- Happy paths: no sync vs with sync (delegates to sync_service.delete) +- Error paths: not found, collection not found """ from dataclasses import dataclass from datetime import datetime, timezone -from typing import Optional from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest -import airweave.domains.source_connections.delete as delete_module from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncJobStatus +from airweave.core.shared_models import AuthMethod from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.source_connections.delete import SourceConnectionDeletionService from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder -from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService -from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.models.collection import Collection from airweave.models.source_connection import SourceConnection -from airweave.models.sync_job import SyncJob from airweave.schemas.organization import Organization -from airweave.schemas.source_connection import SourceConnectionJob NOW = datetime.now(timezone.utc) ORG_ID = uuid4() @@ -67,33 +61,28 @@ def _make_collection(*, id=None, readable_id="test-col"): col = MagicMock(spec=Collection) col.id = id or COLLECTION_ID col.readable_id = readable_id + col.name = "Test Collection" col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None return col -def _make_job(*, status=SyncJobStatus.COMPLETED, sync_id=None): - job = MagicMock(spec=SyncJob) - job.id = uuid4() - job.sync_id = sync_id or uuid4() - job.status = status - return job - - def _build_service( sc_repo=None, collection_repo=None, - sync_job_repo=None, - sync_lifecycle=None, response_builder=None, - temporal_workflow_service=None, + sync_service=None, ): return SourceConnectionDeletionService( sc_repo=sc_repo or FakeSourceConnectionRepository(), collection_repo=collection_repo or FakeCollectionRepository(), - sync_job_repo=sync_job_repo or FakeSyncJobRepository(), - sync_lifecycle=sync_lifecycle or FakeSyncLifecycleService(), response_builder=response_builder or FakeResponseBuilder(), - temporal_workflow_service=temporal_workflow_service or FakeTemporalWorkflowService(), + sync_service=sync_service or FakeSyncService(), ) @@ -106,18 +95,12 @@ def _build_service( class DeleteCase: desc: str has_sync: bool - job_status: Optional[SyncJobStatus] - expect_cancel: bool - expect_wait: bool - expect_cleanup: bool + expect_sync_delete: bool DELETE_CASES = [ - DeleteCase("no_sync", has_sync=False, job_status=None, expect_cancel=False, expect_wait=False, expect_cleanup=False), - DeleteCase("sync_no_running_job", has_sync=True, job_status=SyncJobStatus.COMPLETED, expect_cancel=False, expect_wait=False, expect_cleanup=True), - DeleteCase("running_job", has_sync=True, job_status=SyncJobStatus.RUNNING, expect_cancel=True, expect_wait=True, expect_cleanup=True), - DeleteCase("cancelling_job", has_sync=True, job_status=SyncJobStatus.CANCELLING, expect_cancel=False, expect_wait=True, expect_cleanup=True), - DeleteCase("pending_job", has_sync=True, job_status=SyncJobStatus.PENDING, expect_cancel=True, expect_wait=True, expect_cleanup=True), + DeleteCase("no_sync", has_sync=False, expect_sync_delete=False), + DeleteCase("with_sync", has_sync=True, expect_sync_delete=True), ] @@ -132,38 +115,23 @@ async def test_delete_happy_path(case: DeleteCase): col_repo = FakeCollectionRepository() col_repo.seed_readable(sc.readable_collection_id, col) - job_repo = FakeSyncJobRepository() - if case.job_status is not None and sync_id: - job_repo.seed_last_job(sync_id, _make_job(status=case.job_status, sync_id=sync_id)) - - lifecycle = FakeSyncLifecycleService() - if case.expect_cancel: - lifecycle.set_cancel_result(MagicMock(spec=SourceConnectionJob)) - - temporal = FakeTemporalWorkflowService() + sync_service = FakeSyncService() svc = _build_service( sc_repo=sc_repo, collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, + sync_service=sync_service, ) - if case.expect_wait: - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=True) - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) assert result.id == sc.id assert sc_repo._store.get(sc.id) is None - if case.expect_cancel: - assert any(c[0] == "cancel_job" for c in lifecycle._calls) - if case.expect_wait: - svc._wait_for_sync_job_terminal_state.assert_awaited_once() - if case.expect_cleanup: - assert any(c[0] == "start_cleanup_sync_data_workflow" for c in temporal._calls) + if case.expect_sync_delete: + assert any(c[0] == "delete" for c in sync_service._calls) + else: + assert not any(c[0] == "delete" for c in sync_service._calls) # --------------------------------------------------------------------------- @@ -203,8 +171,8 @@ async def test_delete_error(case: DeleteErrorCase): await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) -async def test_delete_cancel_failure_is_swallowed(): - """Cancel failure during delete is warned but not re-raised.""" +async def test_delete_sync_service_failure_propagates(): + """If sync_service.delete raises, the error propagates.""" sync_id = uuid4() sc = _make_sc(sync_id=sync_id) col = _make_collection() @@ -214,132 +182,14 @@ async def test_delete_cancel_failure_is_swallowed(): col_repo = FakeCollectionRepository() col_repo.seed_readable(sc.readable_collection_id, col) - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - job_repo = FakeSyncJobRepository() - job_repo.seed_last_job(sync_id, running_job) - - lifecycle = FakeSyncLifecycleService() - lifecycle.set_error(RuntimeError("cancel boom")) - temporal = FakeTemporalWorkflowService() + sync_service = FakeSyncService() + sync_service.set_error(RuntimeError("sync delete boom")) svc = _build_service( sc_repo=sc_repo, collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, + sync_service=sync_service, ) - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=True) - - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - -async def test_delete_temporal_cleanup_failure_is_logged(): - """Temporal cleanup failure is logged but not re-raised.""" - sync_id = uuid4() - sc = _make_sc(sync_id=sync_id) - col = _make_collection() - - sc_repo = FakeSourceConnectionRepository() - sc_repo.seed(sc.id, sc) - col_repo = FakeCollectionRepository() - col_repo.seed_readable(sc.readable_collection_id, col) - - job_repo = FakeSyncJobRepository() - completed_job = _make_job(status=SyncJobStatus.COMPLETED, sync_id=sync_id) - job_repo.seed_last_job(sync_id, completed_job) - - temporal = FakeTemporalWorkflowService() - temporal.set_error(RuntimeError("cleanup boom")) - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=col_repo, - sync_job_repo=job_repo, - temporal_workflow_service=temporal, - ) - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - - -async def test_delete_wait_timeout_proceeds(): - """If wait_for_terminal_state returns False, deletion still proceeds.""" - sync_id = uuid4() - sc = _make_sc(sync_id=sync_id) - col = _make_collection() - - sc_repo = FakeSourceConnectionRepository() - sc_repo.seed(sc.id, sc) - col_repo = FakeCollectionRepository() - col_repo.seed_readable(sc.readable_collection_id, col) - - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - job_repo = FakeSyncJobRepository() - job_repo.seed_last_job(sync_id, running_job) - - lifecycle = FakeSyncLifecycleService() - lifecycle.set_cancel_result(MagicMock(spec=SourceConnectionJob)) - temporal = FakeTemporalWorkflowService() - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, - ) - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=False) - - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - assert sc_repo._store.get(sc.id) is None - - -async def test_wait_for_sync_job_terminal_state_reaches_terminal(monkeypatch): - sync_id = uuid4() - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - cancelled_job = _make_job(status=SyncJobStatus.CANCELLED, sync_id=sync_id) - - job_repo = FakeSyncJobRepository() - job_repo.get_latest_by_sync_id = AsyncMock(side_effect=[running_job, cancelled_job]) # type: ignore[method-assign] - svc = _build_service(sync_job_repo=job_repo) - - async def _no_sleep(_: float) -> None: - return None - - monkeypatch.setattr(delete_module.asyncio, "sleep", _no_sleep) - - db = MagicMock() - db.expire_all = MagicMock() - - reached = await svc._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=2, poll_interval=1 - ) - - assert reached is True - assert db.expire_all.call_count == 2 - - -async def test_wait_for_sync_job_terminal_state_times_out(monkeypatch): - sync_id = uuid4() - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - - job_repo = FakeSyncJobRepository() - job_repo.get_latest_by_sync_id = AsyncMock(side_effect=[running_job, running_job]) # type: ignore[method-assign] - svc = _build_service(sync_job_repo=job_repo) - - async def _no_sleep(_: float) -> None: - return None - - monkeypatch.setattr(delete_module.asyncio, "sleep", _no_sleep) - - db = MagicMock() - db.expire_all = MagicMock() - - reached = await svc._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=2, poll_interval=1 - ) - - assert reached is False - assert db.expire_all.call_count == 2 + with pytest.raises(RuntimeError, match="sync delete boom"): + await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) diff --git a/backend/airweave/domains/source_connections/tests/test_fake_service.py b/backend/airweave/domains/source_connections/tests/test_fake_service.py index bd619d0e8..bba4f3e76 100644 --- a/backend/airweave/domains/source_connections/tests/test_fake_service.py +++ b/backend/airweave/domains/source_connections/tests/test_fake_service.py @@ -8,7 +8,7 @@ from airweave.core.exceptions import NotFoundException from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService from airweave.domains.source_connections.fakes.service import FakeSourceConnectionService -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService def _make_source_connection(): @@ -23,7 +23,7 @@ async def test_delete_delegation_removes_fake_store_entry_on_success(): deletion_service.seed_response(source_connection.id, MagicMock()) service = FakeSourceConnectionService( - sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), deletion_service=deletion_service, ) service.seed(source_connection.id, source_connection) @@ -40,7 +40,7 @@ async def test_delete_delegation_keeps_fake_store_entry_when_delegate_raises(): deletion_service.set_should_raise(RuntimeError("boom")) service = FakeSourceConnectionService( - sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), deletion_service=deletion_service, ) service.seed(source_connection.id, source_connection) diff --git a/backend/airweave/domains/source_connections/tests/test_service.py b/backend/airweave/domains/source_connections/tests/test_service.py index b8f2a911e..def0b1169 100644 --- a/backend/airweave/domains/source_connections/tests/test_service.py +++ b/backend/airweave/domains/source_connections/tests/test_service.py @@ -10,31 +10,30 @@ from datetime import datetime, timedelta from typing import Optional from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest +from airweave.api.context import ApiContext from airweave.core.datetime_utils import utc_now from airweave.core.exceptions import NotFoundException - -from airweave.api.context import ApiContext from airweave.core.logging import logger from airweave.core.shared_models import AuthMethod, SourceConnectionStatus, SyncJobStatus from airweave.domains.auth_provider.fake import FakeAuthProviderRegistry from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.connections.fakes.repository import FakeConnectionRepository from airweave.domains.oauth.fakes.repository import FakeOAuthRedirectSessionRepository +from airweave.domains.source_connections.fakes.create import FakeSourceConnectionCreateService +from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder +from airweave.domains.source_connections.fakes.update import FakeSourceConnectionUpdateService from airweave.domains.source_connections.service import SourceConnectionService from airweave.domains.source_connections.types import LastJobInfo, SourceConnectionStats from airweave.domains.sources.fakes.registry import FakeSourceRegistry -from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService -from airweave.domains.source_connections.fakes.create import FakeSourceConnectionCreateService -from airweave.domains.source_connections.fakes.update import FakeSourceConnectionUpdateService -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.schemas.organization import Organization from airweave.schemas.source_connection import AuthenticationMethod, SourceConnectionListItem @@ -99,7 +98,8 @@ def _build_service( source_registry=FakeSourceRegistry(), auth_provider_registry=FakeAuthProviderRegistry(), response_builder=FakeResponseBuilder(), - sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), + event_bus=AsyncMock(), create_service=FakeSourceConnectionCreateService(), update_service=FakeSourceConnectionUpdateService(), deletion_service=FakeSourceConnectionDeletionService(), @@ -399,3 +399,293 @@ async def test_get_redirect_url_missing_code_raises(): svc = _build_service() with pytest.raises(NotFoundException, match="Authorization link expired or invalid"): await svc.get_redirect_url(AsyncMock(), code="nonexistent") + + +# --------------------------------------------------------------------------- +# Sync lifecycle proxies: run, get_jobs, cancel_job, count_by_organization +# --------------------------------------------------------------------------- + +SYNC_ID = uuid4() +SC_ID = uuid4() +JOB_ID = uuid4() +CONN_ID = uuid4() +COL_READABLE_ID = "test-col-abc" + + +def _make_source_conn( + *, + id: UUID = SC_ID, + sync_id: UUID = SYNC_ID, + connection_id: UUID = CONN_ID, + readable_collection_id: str = COL_READABLE_ID, +): + sc = MagicMock() + sc.id = id + sc.sync_id = sync_id + sc.connection_id = connection_id + sc.readable_collection_id = readable_collection_id + return sc + + +def _make_collection(readable_id: str = COL_READABLE_ID): + col = MagicMock() + col.id = uuid4() + col.name = "Test Collection" + col.readable_id = readable_id + col.organization_id = ORG_ID + return col + + +def _make_connection(id: UUID = CONN_ID): + conn = MagicMock() + conn.id = id + conn.short_name = "github" + conn.name = "GitHub" + return conn + + +def _make_sync_schema(): + return MagicMock(spec_set=["id", "status"]) + + +def _make_sync_job_schema(*, sync_id: UUID = SYNC_ID, job_id: UUID = JOB_ID): + job = MagicMock() + job.id = job_id + job.sync_id = sync_id + job.status = SyncJobStatus.PENDING + job.started_at = None + job.completed_at = None + job.error = None + job.error_category = None + job.entities_inserted = 0 + job.entities_updated = 0 + job.entities_deleted = 0 + job.entities_skipped = 0 + return job + + +def _build_run_service( + sc_repo=None, + sync_service=None, + collection_repo=None, + connection_repo=None, + event_bus=None, +): + return SourceConnectionService( + sc_repo=sc_repo or FakeSourceConnectionRepository(), + collection_repo=collection_repo or FakeCollectionRepository(), + connection_repo=connection_repo or FakeConnectionRepository(), + redirect_session_repo=FakeOAuthRedirectSessionRepository(), + source_registry=FakeSourceRegistry(), + auth_provider_registry=FakeAuthProviderRegistry(), + response_builder=FakeResponseBuilder(), + sync_service=sync_service or FakeSyncService(), + event_bus=event_bus or AsyncMock(), + create_service=FakeSourceConnectionCreateService(), + update_service=FakeSourceConnectionUpdateService(), + deletion_service=FakeSourceConnectionDeletionService(), + ) + + +class _RecordingFakeSyncService(FakeSyncService): + """Records keyword arguments passed to trigger_run for assertions.""" + + def __init__(self) -> None: + super().__init__() + self.last_trigger_run: Optional[dict] = None + + async def trigger_run( + self, + db, + *, + sync_id, + collection, + connection, + ctx, + force_full_sync: bool = False, + ): + self.last_trigger_run = { + "sync_id": sync_id, + "collection": collection, + "connection": connection, + "force_full_sync": force_full_sync, + } + return await super().trigger_run( + db, + sync_id=sync_id, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + +async def test_run_triggers_workflow_and_returns_job(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = _RecordingFakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + event_bus = AsyncMock() + + svc = _build_run_service( + sc_repo=sc_repo, + sync_service=sync_svc, + event_bus=event_bus, + ) + + col_id = uuid4() + col_schema = MagicMock(id=col_id, readable_id="col-x") + col_schema.name = "Col" + conn_schema = MagicMock(short_name="github") + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=conn_schema) + + result = await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + assert result.id == JOB_ID + assert result.source_connection_id == SC_ID + assert result.status == SyncJobStatus.PENDING + assert sync_svc.last_trigger_run is not None + assert sync_svc.last_trigger_run["sync_id"] == SYNC_ID + assert sync_svc.last_trigger_run["collection"] is col_schema + assert sync_svc.last_trigger_run["connection"] is conn_schema + assert sync_svc.last_trigger_run["force_full_sync"] is False + + +async def test_run_event_failure_propagates(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = FakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + event_bus = AsyncMock() + event_bus.publish.side_effect = RuntimeError("event bus down") + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc, event_bus=event_bus) + col_schema = MagicMock(id=uuid4(), readable_id="col-x") + col_schema.name = "Col" + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=MagicMock(short_name="github")) + + with pytest.raises(RuntimeError, match="event bus down"): + await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + +async def test_run_not_found_raises(): + svc = _build_run_service() + with pytest.raises(NotFoundException, match="Source connection not found"): + await svc.run(AsyncMock(), id=uuid4(), ctx=_make_ctx()) + + +async def test_get_jobs_returns_mapped_jobs(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = FakeSyncService() + j1 = _make_sync_job_schema(job_id=uuid4()) + j1.entities_inserted = 42 + j1.entities_updated = 3 + j2 = _make_sync_job_schema(job_id=uuid4()) + sync_svc.seed_jobs(SYNC_ID, [j1, j2]) + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc) + jobs = await svc.get_jobs(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + assert len(jobs) == 2 + assert jobs[0].source_connection_id == SC_ID + assert jobs[0].entities_inserted == 42 + assert jobs[0].entities_updated == 3 + assert jobs[1].source_connection_id == SC_ID + assert jobs[1].entities_inserted == 0 + + +async def test_get_jobs_not_found_raises(): + svc = _build_run_service() + with pytest.raises(NotFoundException, match="Source connection not found"): + await svc.get_jobs(AsyncMock(), id=uuid4(), ctx=_make_ctx()) + + +async def test_cancel_job_delegates_to_sync_service(): + sync_svc = FakeSyncService() + job = _make_sync_job_schema(sync_id=SYNC_ID) + sync_svc.set_cancel_result(job) + + svc = _build_run_service(sync_service=sync_svc) + result = await svc.cancel_job( + AsyncMock(), source_connection_id=SC_ID, job_id=JOB_ID, ctx=_make_ctx() + ) + assert result.id == JOB_ID + + +async def test_run_with_force_full_sync(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = _RecordingFakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc) + col_schema = MagicMock(id=uuid4(), readable_id="col-x") + col_schema.name = "Col" + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=MagicMock(short_name="github")) + + result = await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx(), force_full_sync=True) + assert result.id == JOB_ID + assert ("validate_force_full_sync", SYNC_ID) in sync_svc._calls + assert sync_svc.last_trigger_run is not None + assert sync_svc.last_trigger_run["force_full_sync"] is True + + +async def test_resolve_collection_not_found(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + col_repo = FakeCollectionRepository() + svc = _build_run_service(sc_repo=sc_repo, collection_repo=col_repo) + + with pytest.raises(NotFoundException, match="Collection not found"): + await svc._resolve_collection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_collection_no_readable_id(): + sc = _make_source_conn(readable_collection_id=None) + svc = _build_run_service() + + with pytest.raises(NotFoundException, match="has no readable_collection_id"): + await svc._resolve_collection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_connection_not_found(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + conn_repo = FakeConnectionRepository() + svc = _build_run_service(sc_repo=sc_repo, connection_repo=conn_repo) + + with pytest.raises(NotFoundException, match="not found"): + await svc._resolve_connection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_connection_no_connection_id(): + sc = _make_source_conn(connection_id=None) + svc = _build_run_service() + + with pytest.raises(NotFoundException, match="has no connection_id"): + await svc._resolve_connection(AsyncMock(), sc, _make_ctx()) + + +async def test_count_by_organization(): + sc_repo = FakeSourceConnectionRepository() + svc = _build_run_service(sc_repo=sc_repo) + count = await svc.count_by_organization(AsyncMock(), organization_id=ORG_ID) + assert count == 0 diff --git a/backend/airweave/domains/source_connections/tests/test_update.py b/backend/airweave/domains/source_connections/tests/test_update.py index 9063c4bc1..581cec880 100644 --- a/backend/airweave/domains/source_connections/tests/test_update.py +++ b/backend/airweave/domains/source_connections/tests/test_update.py @@ -18,7 +18,7 @@ from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncStatus +from airweave.core.shared_models import AuthMethod from airweave.domains.syncs.types import InvalidSyncTransitionError from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.connections.fakes.repository import FakeConnectionRepository @@ -30,10 +30,13 @@ ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder from airweave.domains.source_connections.update import SourceConnectionUpdateService +from airweave.domains.sources.fakes.registry import FakeSourceRegistry from airweave.domains.sources.fakes.service import FakeSourceService from airweave.domains.sources.fakes.validation import FakeSourceValidationService -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService +from airweave.domains.sources.types import SourceRegistryEntry +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.domains.syncs.fakes.repository import FakeSyncRepository +from airweave.domains.syncs.types import SyncProvisionResult from airweave.domains.temporal.fakes.schedule_service import FakeTemporalScheduleService from airweave.models.collection import Collection from airweave.models.connection import Connection @@ -105,13 +108,13 @@ def _build_service( connection_repo=None, cred_repo=None, sync_repo=None, - sync_record_service=None, + sync_service=None, source_service=None, + source_registry=None, source_validation=None, credential_encryptor=None, response_builder=None, temporal_schedule_service=None, - sync_state_machine=None, ): return SourceConnectionUpdateService( sc_repo=sc_repo or FakeSourceConnectionRepository(), @@ -119,13 +122,13 @@ def _build_service( connection_repo=connection_repo or FakeConnectionRepository(), cred_repo=cred_repo or FakeIntegrationCredentialRepository(), sync_repo=sync_repo or FakeSyncRepository(), - sync_record_service=sync_record_service or FakeSyncRecordService(), + sync_service=sync_service or FakeSyncService(), source_service=source_service or FakeSourceService(), + source_registry=source_registry or FakeSourceRegistry(), source_validation=source_validation or FakeSourceValidationService(), credential_encryptor=credential_encryptor or FakeCredentialEncryptor(), response_builder=response_builder or FakeResponseBuilder(), temporal_schedule_service=temporal_schedule_service or FakeTemporalScheduleService(), - sync_state_machine=sync_state_machine or AsyncMock(), ) @@ -211,7 +214,7 @@ class ScheduleCase: SCHEDULE_CASES = [ ScheduleCase("update_existing", has_sync=True, new_cron="0 * * * *", expect_temporal_create=True, expect_temporal_delete=False), ScheduleCase("remove_schedule", has_sync=True, new_cron=None, expect_temporal_create=False, expect_temporal_delete=True), - ScheduleCase("add_no_sync", has_sync=False, new_cron="0 * * * *", expect_temporal_create=True, expect_temporal_delete=False, expect_sync_record_create=True), + ScheduleCase("add_no_sync", has_sync=False, new_cron="0 * * * *", expect_temporal_create=False, expect_temporal_delete=False, expect_sync_record_create=True), ScheduleCase("no_connection_id_warning", has_sync=False, new_cron="0 * * * *", has_connection_id=False, expect_temporal_create=False, expect_temporal_delete=False), ] @@ -235,17 +238,38 @@ async def test_schedule_update(case: ScheduleCase): source_svc.seed(_make_source_schema(short_name="github")) temporal = FakeTemporalScheduleService() + sync_svc = FakeSyncService() + + source_registry = FakeSourceRegistry() + source_entry = MagicMock(spec=SourceRegistryEntry) + source_entry.short_name = "github" + source_entry.federated_search = False + source_registry.seed(source_entry) - sync_record_svc = FakeSyncRecordService() if case.expect_sync_record_create: - mock_sync = MagicMock(spec=schemas.Sync) - mock_sync.id = uuid4() - sync_record_svc.set_create_result(mock_sync) + created_sync_id = uuid4() + mock_sync_schema = MagicMock(spec=schemas.Sync) + mock_sync_schema.id = created_sync_id + sync_svc.set_create_result( + SyncProvisionResult( + sync_id=created_sync_id, + sync=mock_sync_schema, + sync_job=None, + cron_schedule=case.new_cron, + ) + ) col = MagicMock(spec=Collection) col.id = uuid4() col.readable_id = "test-col" + col.name = "Test Collection" col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None col_repo = FakeCollectionRepository() col_repo.seed_readable("test-col", col) else: @@ -257,9 +281,10 @@ async def test_schedule_update(case: ScheduleCase): svc = _build_service( sc_repo=sc_repo, sync_repo=sync_repo, + sync_service=sync_svc, source_service=source_svc, + source_registry=source_registry, temporal_schedule_service=temporal, - sync_record_service=sync_record_svc, collection_repo=col_repo, ) @@ -271,7 +296,7 @@ async def test_schedule_update(case: ScheduleCase): if case.expect_temporal_delete: assert any(c[0] == "delete_all_schedules_for_sync" for c in temporal._calls) if case.expect_sync_record_create: - assert any(c[0] == "create_sync" for c in sync_record_svc._calls) + assert any(c[0] == "create" for c in sync_svc._calls) async def test_schedule_add_collection_not_found(): @@ -289,6 +314,45 @@ async def test_schedule_add_collection_not_found(): await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) +async def test_schedule_add_rejects_federated_source(): + """Adding a schedule to a federated search source is rejected with 400.""" + sc = _make_sc(sync_id=None) + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(sc.id, sc) + + col = MagicMock(spec=Collection) + col.id = uuid4() + col.readable_id = "test-col" + col.name = "Test Collection" + col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None + col_repo = FakeCollectionRepository() + col_repo.seed_readable("test-col", col) + + federated_entry = MagicMock(spec=SourceRegistryEntry) + federated_entry.short_name = "github" + federated_entry.federated_search = True + source_registry = FakeSourceRegistry() + source_registry.seed(federated_entry) + + svc = _build_service( + sc_repo=sc_repo, + collection_repo=col_repo, + source_registry=source_registry, + ) + obj_in = SourceConnectionUpdate(schedule={"cron": "0 * * * *"}) + + with pytest.raises(HTTPException) as exc_info: + await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) + assert exc_info.value.status_code == 400 + assert "federated search" in str(exc_info.value.detail) + + # --------------------------------------------------------------------------- # Credential updates -- table-driven # --------------------------------------------------------------------------- @@ -480,7 +544,7 @@ def test_cron_validation(case: CronCase): @pytest.mark.asyncio async def test_credential_update_triggers_unpause(): - """Successful direct auth credential update calls sync_state_machine.transition → ACTIVE.""" + """Successful direct auth credential update calls sync_service.resume.""" conn_id = uuid4() cred_id = uuid4() sync_id = uuid4() @@ -503,7 +567,7 @@ async def test_credential_update_triggers_unpause(): validation = FakeSourceValidationService() validation.seed_auth_result("github", _AuthPayload(token="secret")) - state_machine = AsyncMock() + sync_svc = FakeSyncService() svc = _build_service( sc_repo=sc_repo, @@ -511,24 +575,18 @@ async def test_credential_update_triggers_unpause(): cred_repo=cred_repo, source_validation=validation, credential_encryptor=FakeCredentialEncryptor(), - sync_state_machine=state_machine, + sync_service=sync_svc, ) obj_in = SourceConnectionUpdate(authentication={"credentials": {"token": "new_secret"}}) await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) - state_machine.transition.assert_called_once() - call_kwargs = state_machine.transition.call_args - from airweave.core.shared_models import SyncStatus - - assert call_kwargs.kwargs.get("target") == SyncStatus.ACTIVE or ( - len(call_kwargs.args) >= 2 and call_kwargs.args[1] == SyncStatus.ACTIVE - ) + assert any(c[0] == "resume" for c in sync_svc._calls) @pytest.mark.asyncio async def test_credential_update_unpause_failure_is_nonfatal(): - """If sync_state_machine.transition raises, the update still succeeds.""" + """If sync_service.resume raises, the update still succeeds.""" conn_id = uuid4() cred_id = uuid4() sync_id = uuid4() @@ -551,10 +609,8 @@ async def test_credential_update_unpause_failure_is_nonfatal(): validation = FakeSourceValidationService() validation.seed_auth_result("github", _AuthPayload(token="secret")) - state_machine = AsyncMock() - state_machine.transition.side_effect = InvalidSyncTransitionError( - SyncStatus.ACTIVE, SyncStatus.ACTIVE - ) + sync_svc = FakeSyncService() + sync_svc.set_error(ValueError("sync not active")) svc = _build_service( sc_repo=sc_repo, @@ -562,7 +618,7 @@ async def test_credential_update_unpause_failure_is_nonfatal(): cred_repo=cred_repo, source_validation=validation, credential_encryptor=FakeCredentialEncryptor(), - sync_state_machine=state_machine, + sync_service=sync_svc, ) obj_in = SourceConnectionUpdate(authentication={"credentials": {"token": "new_secret"}}) diff --git a/backend/airweave/domains/source_connections/update.py b/backend/airweave/domains/source_connections/update.py index 80af6484d..670220f2a 100644 --- a/backend/airweave/domains/source_connections/update.py +++ b/backend/airweave/domains/source_connections/update.py @@ -11,7 +11,6 @@ from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.protocols.encryption import CredentialEncryptor -from airweave.core.shared_models import SyncStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -23,19 +22,17 @@ ) from airweave.domains.sources.exceptions import SourceNotFoundError from airweave.domains.sources.protocols import ( + SourceRegistryProtocol, SourceServiceProtocol, SourceValidationServiceProtocol, ) -from airweave.domains.syncs.protocols import ( - SyncRecordServiceProtocol, - SyncRepositoryProtocol, - SyncStateMachineProtocol, -) +from airweave.domains.syncs.protocols import SyncRepositoryProtocol, SyncServiceProtocol from airweave.domains.syncs.types import InvalidSyncTransitionError, OptimisticLockError from airweave.domains.temporal.protocols import TemporalScheduleServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( AuthenticationMethod, + ScheduleConfig, SourceConnectionUpdate, ) from airweave.schemas.source_connection import ( @@ -59,13 +56,13 @@ def __init__( connection_repo: ConnectionRepositoryProtocol, cred_repo: IntegrationCredentialRepositoryProtocol, sync_repo: SyncRepositoryProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: SyncServiceProtocol, source_service: SourceServiceProtocol, + source_registry: SourceRegistryProtocol, source_validation: SourceValidationServiceProtocol, credential_encryptor: CredentialEncryptor, response_builder: ResponseBuilderProtocol, temporal_schedule_service: TemporalScheduleServiceProtocol, - sync_state_machine: SyncStateMachineProtocol, ) -> None: """Initialize with repositories and collaborator services.""" self._sc_repo = sc_repo @@ -73,13 +70,13 @@ def __init__( self._connection_repo = connection_repo self._cred_repo = cred_repo self._sync_repo = sync_repo - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._source_service = source_service + self._source_registry = source_registry self._source_validation = source_validation self._credential_encryptor = credential_encryptor self._response_builder = response_builder self._temporal_schedule_service = temporal_schedule_service - self._sync_state_machine = sync_state_machine async def update( self, @@ -134,10 +131,9 @@ async def update( if source_conn.sync_id: try: - await self._sync_state_machine.transition( - sync_id=source_conn.sync_id, - target=SyncStatus.ACTIVE, - ctx=ctx, + await self._sync_service.resume( + source_conn.sync_id, + ctx, reason="Credential update completed", ) except (InvalidSyncTransitionError, OptimisticLockError, ValueError): @@ -203,62 +199,55 @@ async def _handle_schedule_update( uow, ) elif new_cron: - # No sync exists but we're adding a schedule - create a new sync - # Get the source to validate schedule - source = await self._get_and_validate_source(source_conn.short_name, ctx) - self._validate_cron_schedule_for_source(new_cron, source, ctx) - - # Check if connection_id exists (might be None for OAuth flows) if not source_conn.connection_id: ctx.logger.warning( f"Cannot create schedule for SC {source_conn.id} without connection_id" ) - # Skip schedule creation for connections without connection_id del update_data["schedule"] return - # Get the collection - collection = await self._collection_repo.get_by_readable_id( + collection_orm = await self._collection_repo.get_by_readable_id( uow.session, readable_id=source_conn.readable_collection_id, ctx=ctx ) - if not collection: + if not collection_orm: raise NotFoundException("Collection not found") + collection = schemas.CollectionRecord.model_validate( + collection_orm, from_attributes=True + ) - # Resolve destination IDs - dest_ids = await self._sync_record_service.resolve_destination_ids(uow.session, ctx) + source_entry = self._source_registry.get(source_conn.short_name) + if source_entry.federated_search: + raise HTTPException( + status_code=400, + detail=f"Source '{source_conn.short_name}' is a federated search source " + "and does not support scheduled syncs.", + ) - # Create a new sync with the schedule - sync, _ = await self._sync_record_service.create_sync( + dest_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + + sync_result = await self._sync_service.create( uow.session, - name=f"Sync for {source_conn.name}", + name=source_conn.name, source_connection_id=source_conn.connection_id, destination_connection_ids=dest_ids, - cron_schedule=new_cron, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=source_entry, + schedule_config=ScheduleConfig(cron=new_cron), run_immediately=False, ctx=ctx, uow=uow, ) - # Apply the sync_id update to the source connection now - # so that temporal_schedule_service can find it source_conn = await self._sc_repo.update( uow.session, db_obj=source_conn, - obj_in={"sync_id": sync.id}, + obj_in={"sync_id": sync_result.sync_id}, ctx=ctx, uow=uow, ) await uow.session.flush() - # Create the Temporal schedule - await self._temporal_schedule_service.create_or_update_schedule( - sync_id=sync.id, - cron_schedule=new_cron, - db=uow.session, - ctx=ctx, - uow=uow, - ) - if "schedule" in update_data: del update_data["schedule"] diff --git a/backend/airweave/domains/syncs/fakes/lifecycle_service.py b/backend/airweave/domains/syncs/fakes/lifecycle_service.py deleted file mode 100644 index 0f5be25f2..000000000 --- a/backend/airweave/domains/syncs/fakes/lifecycle_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Fake sync lifecycle service for testing.""" - -from typing import List, Optional -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave.api.context import ApiContext -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.types import SyncProvisionResult -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob - - -class FakeSyncLifecycleService: - """In-memory fake for SyncLifecycleServiceProtocol.""" - - def __init__(self) -> None: - """Initialize with empty state.""" - self._calls: list[tuple] = [] - self._provision_result: Optional[SyncProvisionResult] = None - self._run_result: Optional[SourceConnectionJob] = None - self._jobs: dict[UUID, List[SourceConnectionJob]] = {} - self._cancel_result: Optional[SourceConnectionJob] = None - self._should_raise: Optional[Exception] = None - - def set_provision_result(self, result: Optional[SyncProvisionResult]) -> None: - """Configure provision_sync() return value.""" - self._provision_result = result - - def set_run_result(self, result: SourceConnectionJob) -> None: - """Configure run() return value.""" - self._run_result = result - - def seed_jobs(self, sc_id: UUID, jobs: List[SourceConnectionJob]) -> None: - """Seed jobs returned by get_jobs.""" - self._jobs[sc_id] = jobs - - def set_cancel_result(self, result: SourceConnectionJob) -> None: - """Configure cancel_job() return value.""" - self._cancel_result = result - - def set_error(self, error: Exception) -> None: - """Make all subsequent calls raise this error.""" - self._should_raise = error - - async def teardown_syncs_for_collection( - self, - db: AsyncSession, - *, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - cancel_timeout_seconds: int = 15, - ) -> None: - """Record call — no-op fake.""" - self._calls.append( - ("teardown_syncs_for_collection", db, sync_ids, collection_id, organization_id, ctx) - ) - if self._should_raise: - raise self._should_raise - - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Record call and return canned result.""" - self._calls.append(("provision_sync", name, source_connection_id, collection_id)) - if self._should_raise: - raise self._should_raise - return self._provision_result - - async def run( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Record call and return canned result.""" - self._calls.append(("run", db, id, ctx, force_full_sync)) - if self._should_raise: - raise self._should_raise - if self._run_result is None: - raise RuntimeError("FakeSyncLifecycleService.run_result not configured") - return self._run_result - - async def get_jobs( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - limit: int = 100, - ) -> List[SourceConnectionJob]: - """Record call and return seeded jobs.""" - self._calls.append(("get_jobs", db, id, ctx, limit)) - if self._should_raise: - raise self._should_raise - return self._jobs.get(id, [])[:limit] - - async def cancel_job( - self, - db: AsyncSession, - *, - source_connection_id: UUID, - job_id: UUID, - ctx: ApiContext, - ) -> SourceConnectionJob: - """Record call and return canned result.""" - self._calls.append(("cancel_job", db, source_connection_id, job_id, ctx)) - if self._should_raise: - raise self._should_raise - if self._cancel_result is None: - raise RuntimeError("FakeSyncLifecycleService.cancel_result not configured") - return self._cancel_result diff --git a/backend/airweave/domains/syncs/fakes/record_service.py b/backend/airweave/domains/syncs/fakes/record_service.py deleted file mode 100644 index f0d45fb95..000000000 --- a/backend/airweave/domains/syncs/fakes/record_service.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Fake sync record service for testing.""" - -from typing import List, Optional, Tuple -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.db.unit_of_work import UnitOfWork - - -class FakeSyncRecordService: - """In-memory fake for SyncRecordServiceProtocol.""" - - def __init__(self) -> None: - """Initialize with empty state.""" - self._calls: list[tuple] = [] - self._create_result: Optional[Tuple[schemas.Sync, Optional[schemas.SyncJob]]] = None - self._trigger_result: Optional[Tuple[schemas.Sync, schemas.SyncJob]] = None - self._resolve_dest_ids: Optional[List[UUID]] = None - self._should_raise: Optional[Exception] = None - - def set_create_result( - self, sync: schemas.Sync, sync_job: Optional[schemas.SyncJob] = None - ) -> None: - """Configure create_sync return value.""" - self._create_result = (sync, sync_job) - - def set_trigger_result(self, sync: schemas.Sync, sync_job: schemas.SyncJob) -> None: - """Configure trigger_sync_run return value.""" - self._trigger_result = (sync, sync_job) - - def set_resolve_dest_ids(self, ids: List[UUID]) -> None: - """Configure resolve_destination_ids return value.""" - self._resolve_dest_ids = ids - - def set_error(self, error: Exception) -> None: - """Make all subsequent calls raise this error.""" - self._should_raise = error - - async def resolve_destination_ids( - self, - db: AsyncSession, - ctx: ApiContext, - ) -> List[UUID]: - """Record call and return canned result.""" - self._calls.append(("resolve_destination_ids",)) - if self._should_raise: - raise self._should_raise - if self._resolve_dest_ids is None: - from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID - - return [NATIVE_VESPA_UUID] - return self._resolve_dest_ids - - async def create_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - cron_schedule: Optional[str], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Record call and return canned result.""" - self._calls.append( - ("create_sync", name, source_connection_id, cron_schedule, run_immediately) - ) - if self._should_raise: - raise self._should_raise - if self._create_result is None: - raise RuntimeError("FakeSyncRecordService.create_result not configured") - return self._create_result - - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Record call and return canned result.""" - self._calls.append(("trigger_sync_run", db, sync_id, ctx)) - if self._should_raise: - raise self._should_raise - if self._trigger_result is None: - raise RuntimeError("FakeSyncRecordService.trigger_result not configured") - return self._trigger_result diff --git a/backend/airweave/domains/syncs/fakes/service.py b/backend/airweave/domains/syncs/fakes/service.py index d31050fb6..eade30d2e 100644 --- a/backend/airweave/domains/syncs/fakes/service.py +++ b/backend/airweave/domains/syncs/fakes/service.py @@ -1,17 +1,181 @@ -"""Fake sync service for testing.""" +"""Fake sync service for testing — matches unified SyncServiceProtocol.""" -from typing import Optional +from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas from airweave.api.context import ApiContext +from airweave.core.context import BaseContext +from airweave.core.shared_models import SyncStatus +from airweave.db.unit_of_work import UnitOfWork +from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.syncs.protocols import SyncServiceProtocol +from airweave.domains.syncs.types import SyncProvisionResult, SyncTransitionResult +from airweave.schemas.source_connection import ScheduleConfig -class FakeSyncService: - """In-memory fake for SyncServiceProtocol.""" +class FakeSyncService(SyncServiceProtocol): + """In-memory fake for the unified SyncServiceProtocol.""" def __init__(self) -> None: self._calls: list[tuple] = [] + self._create_result: Optional[SyncProvisionResult] = None + self._get_result: Optional[schemas.Sync] = None + self._trigger_run_result: Optional[Tuple[schemas.Sync, schemas.SyncJob]] = None + self._jobs: Dict[UUID, List[schemas.SyncJob]] = {} + self._cancel_result: Optional[schemas.SyncJob] = None + self._resolve_dest_ids: Optional[List[UUID]] = None + self._run_result: Optional[schemas.Sync] = None + self._should_raise: Optional[Exception] = None + + # -- Configuration helpers -- + + def set_create_result(self, result: SyncProvisionResult) -> None: + self._create_result = result + + def set_get_result(self, result: schemas.Sync) -> None: + self._get_result = result + + def set_trigger_run_result(self, sync: schemas.Sync, job: schemas.SyncJob) -> None: + self._trigger_run_result = (sync, job) + + def seed_jobs(self, sync_id: UUID, jobs: List[schemas.SyncJob]) -> None: + self._jobs[sync_id] = jobs + + def set_cancel_result(self, result: schemas.SyncJob) -> None: + self._cancel_result = result + + def set_resolve_dest_ids(self, ids: List[UUID]) -> None: + self._resolve_dest_ids = ids + + def set_run_result(self, result: schemas.Sync) -> None: + self._run_result = result + + def set_error(self, error: Exception) -> None: + self._should_raise = error + + # -- Lifecycle -- + + async def create( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> SyncProvisionResult: + self._calls.append(("create", name, source_connection_id, collection_id)) + if self._should_raise: + raise self._should_raise + if self._create_result is None: + raise RuntimeError("FakeSyncService.create_result not configured") + return self._create_result + + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + self._calls.append(("get", sync_id)) + if self._should_raise: + raise self._should_raise + if self._get_result is None: + raise ValueError(f"Sync {sync_id} not found") + return self._get_result + + async def pause( + self, sync_id: UUID, ctx: BaseContext, *, reason: str = "" + ) -> SyncTransitionResult: + self._calls.append(("pause", sync_id, reason)) + if self._should_raise: + raise self._should_raise + return SyncTransitionResult( + applied=True, previous=SyncStatus.ACTIVE, current=SyncStatus.PAUSED + ) + + async def resume( + self, sync_id: UUID, ctx: BaseContext, *, reason: str = "" + ) -> SyncTransitionResult: + self._calls.append(("resume", sync_id, reason)) + if self._should_raise: + raise self._should_raise + return SyncTransitionResult( + applied=True, previous=SyncStatus.PAUSED, current=SyncStatus.ACTIVE + ) + + async def delete( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + cancel_timeout_seconds: int = 15, + ) -> None: + self._calls.append(("delete", sync_id, collection_id, organization_id)) + if self._should_raise: + raise self._should_raise + + # -- Jobs -- + + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + self._calls.append(("resolve_destination_ids",)) + if self._should_raise: + raise self._should_raise + if self._resolve_dest_ids is not None: + return self._resolve_dest_ids + from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID + + return [NATIVE_VESPA_UUID] + + async def trigger_run( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: ApiContext, + force_full_sync: bool = False, + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + self._calls.append(("trigger_run", sync_id)) + if self._should_raise: + raise self._should_raise + if self._trigger_run_result is None: + raise RuntimeError("FakeSyncService.trigger_run_result not configured") + return self._trigger_run_result + + async def get_jobs( + self, db: AsyncSession, *, sync_id: UUID, ctx: ApiContext, limit: int = 100 + ) -> List[schemas.SyncJob]: + self._calls.append(("get_jobs", sync_id, limit)) + if self._should_raise: + raise self._should_raise + return self._jobs.get(sync_id, [])[:limit] + + async def cancel_job( + self, db: AsyncSession, *, job_id: UUID, ctx: ApiContext + ) -> schemas.SyncJob: + self._calls.append(("cancel_job", job_id)) + if self._should_raise: + raise self._should_raise + if self._cancel_result is None: + raise RuntimeError("FakeSyncService.cancel_result not configured") + return self._cancel_result + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + self._calls.append(("validate_force_full_sync", sync_id)) + + # -- Execution -- async def run( self, @@ -22,6 +186,9 @@ async def run( ctx: ApiContext, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, + access_token: Optional[str] = None, ) -> schemas.Sync: self._calls.append(("run", sync, sync_job)) - return sync + if self._should_raise: + raise self._should_raise + return self._run_result or sync diff --git a/backend/airweave/domains/syncs/jobs/fakes/repository.py b/backend/airweave/domains/syncs/jobs/fakes/repository.py index 16867e26f..7419acea6 100644 --- a/backend/airweave/domains/syncs/jobs/fakes/repository.py +++ b/backend/airweave/domains/syncs/jobs/fakes/repository.py @@ -54,11 +54,19 @@ async def get_active_for_sync( return [j for j in jobs if j.status in ("PENDING", "RUNNING", "CANCELLING")] async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + self, + db: AsyncSession, + sync_id: UUID, + ctx: ApiContext, + limit: Optional[int] = None, ) -> List[SyncJob]: - """Return all seeded jobs for the sync.""" - self._calls.append(("get_all_by_sync_id", db, sync_id, ctx)) - return self._by_sync.get(sync_id, []) + """Return all seeded jobs for the sync (newest first; optional limit).""" + self._calls.append(("get_all_by_sync_id", db, sync_id, ctx, limit)) + jobs = list(self._by_sync.get(sync_id, [])) + jobs.sort(key=lambda j: j.created_at, reverse=True) + if limit is not None: + jobs = jobs[:limit] + return jobs async def create( self, diff --git a/backend/airweave/domains/syncs/jobs/protocols.py b/backend/airweave/domains/syncs/jobs/protocols.py index 59eb906fe..3143426f7 100644 --- a/backend/airweave/domains/syncs/jobs/protocols.py +++ b/backend/airweave/domains/syncs/jobs/protocols.py @@ -35,7 +35,11 @@ async def get_active_for_sync( ... async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: BaseContext + self, + db: AsyncSession, + sync_id: UUID, + ctx: BaseContext, + limit: Optional[int] = None, ) -> List[SyncJob]: """Get all jobs for a specific sync.""" ... diff --git a/backend/airweave/domains/syncs/jobs/repository.py b/backend/airweave/domains/syncs/jobs/repository.py index 9a504b4ee..66d0000aa 100644 --- a/backend/airweave/domains/syncs/jobs/repository.py +++ b/backend/airweave/domains/syncs/jobs/repository.py @@ -8,6 +8,7 @@ from airweave import crud from airweave.api.context import ApiContext +from airweave.core.context import BaseContext from airweave.core.shared_models import SyncJobStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol @@ -41,10 +42,14 @@ async def get_active_for_sync( ) async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + self, + db: AsyncSession, + sync_id: UUID, + ctx: BaseContext, + limit: Optional[int] = None, ) -> List[SyncJob]: """Get all jobs for a specific sync.""" - return await crud.sync_job.get_all_by_sync_id(db, sync_id=sync_id) + return await crud.sync_job.get_all_by_sync_id(db, sync_id=sync_id, limit=limit) async def create( self, diff --git a/backend/airweave/domains/syncs/lifecycle_service.py b/backend/airweave/domains/syncs/lifecycle_service.py deleted file mode 100644 index e2912a462..000000000 --- a/backend/airweave/domains/syncs/lifecycle_service.py +++ /dev/null @@ -1,466 +0,0 @@ -"""Sync lifecycle service: provision, run, get_jobs, cancel_job, teardown.""" - -import asyncio -import re -from datetime import datetime, timezone -from typing import List, Optional -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.events.sync import SyncLifecycleEvent -from airweave.core.protocols.event_bus import EventBus -from airweave.core.shared_models import SyncJobStatus -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.collections.protocols import CollectionRepositoryProtocol -from airweave.domains.connections.protocols import ConnectionRepositoryProtocol -from airweave.domains.source_connections.protocols import ( - ResponseBuilderProtocol, - SourceConnectionRepositoryProtocol, -) -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.jobs.protocols import ( - SyncJobRepositoryProtocol, - SyncJobStateMachineProtocol, -) -from airweave.domains.syncs.protocols import ( - SyncCursorRepositoryProtocol, - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, -) -from airweave.domains.syncs.types import ( - CONTINUOUS_SOURCE_DEFAULT_CRON, - DAILY_CRON_TEMPLATE, - SyncProvisionResult, -) -from airweave.domains.temporal.protocols import ( - TemporalScheduleServiceProtocol, - TemporalWorkflowServiceProtocol, -) -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob - -_SUB_HOURLY_PATTERN = re.compile(r"^\*/([1-5]?[0-9]) \* \* \* \*$") - - -class SyncLifecycleService(SyncLifecycleServiceProtocol): - """API-facing facade for sync lifecycle: provision, run, get_jobs, cancel_job.""" - - def __init__( - self, - sc_repo: SourceConnectionRepositoryProtocol, - collection_repo: CollectionRepositoryProtocol, - connection_repo: ConnectionRepositoryProtocol, - sync_cursor_repo: SyncCursorRepositoryProtocol, - sync_service: SyncRecordServiceProtocol, - state_machine: SyncJobStateMachineProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - temporal_workflow_service: TemporalWorkflowServiceProtocol, - temporal_schedule_service: TemporalScheduleServiceProtocol, - response_builder: ResponseBuilderProtocol, - event_bus: EventBus, - ) -> None: - """Initialize with all injected dependencies.""" - self._sc_repo = sc_repo - self._collection_repo = collection_repo - self._connection_repo = connection_repo - self._sync_cursor_repo = sync_cursor_repo - self._sync_service = sync_service - self._state_machine = state_machine - self._sync_job_repo = sync_job_repo - self._temporal_workflow_service = temporal_workflow_service - self._temporal_schedule_service = temporal_schedule_service - self._response_builder = response_builder - self._event_bus = event_bus - - # ------------------------------------------------------------------ - # Public API (protocol surface) - # ------------------------------------------------------------------ - - async def teardown_syncs_for_collection( - self, - db: AsyncSession, - *, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - cancel_timeout_seconds: int = 15, - ) -> None: - """Cancel running workflows and schedule async cleanup for a collection's syncs. - - 1. Cancels PENDING/RUNNING workflows via Temporal. - 2. Polls until terminal state (up to cancel_timeout_seconds). - 3. Schedules async cleanup workflow for Vespa/ARF/schedules. - """ - syncs_to_wait = await self._cancel_active_syncs(db, sync_ids, ctx) - await self._wait_for_terminal(db, syncs_to_wait, cancel_timeout_seconds, ctx) - await self._schedule_collection_cleanup(sync_ids, collection_id, organization_id, ctx) - - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Create sync + job + Temporal schedule atomically. - - Returns None for federated search sources (no sync needed) - or when there is neither a schedule nor an immediate run. - """ - if source_entry.federated_search: - ctx.logger.info(f"Skipping sync for federated source '{source_entry.short_name}'") - return None - - cron = self._resolve_cron(schedule_config, source_entry, ctx) - - if not cron and not run_immediately: - ctx.logger.info("No cron schedule and run_immediately=False, skipping sync creation") - return None - - if cron: - self._validate_cron_for_source(cron, source_entry) - - sync_schema, sync_job_schema = await self._sync_service.create_sync( - uow.session, - name=f"Sync for {name}", - source_connection_id=source_connection_id, - destination_connection_ids=destination_connection_ids, - cron_schedule=cron, - run_immediately=run_immediately, - ctx=ctx, - uow=uow, - ) - - if cron: - await self._temporal_schedule_service.create_or_update_schedule( - sync_id=sync_schema.id, - cron_schedule=cron, - db=uow.session, - ctx=ctx, - uow=uow, - collection_readable_id=collection_readable_id, - connection_id=source_connection_id, - ) - - return SyncProvisionResult( - sync_id=sync_schema.id, - sync=sync_schema, - sync_job=sync_job_schema, - cron_schedule=cron, - ) - - async def run( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Trigger a sync run for a source connection. - - Args: - db: Database session. - id: Source connection ID. - ctx: API context. - force_full_sync: Only valid for continuous syncs. - """ - source_conn = await self._sc_repo.get(db, id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - raise HTTPException(status_code=400, detail="Source connection has no associated sync") - - sc_id = source_conn.id - sc_sync_id = source_conn.sync_id - - if force_full_sync: - await self._validate_force_full_sync(db, sc_sync_id, ctx) - - collection = await self._collection_repo.get_by_readable_id( - db, source_conn.readable_collection_id, ctx - ) - collection_schema = schemas.CollectionRecord.model_validate( - collection, from_attributes=True - ) - - connection_schema = await self._resolve_connection(db, source_conn, ctx) - - sync, sync_job = await self._sync_service.trigger_sync_run(db, sync_id=sc_sync_id, ctx=ctx) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - await self._event_bus.publish( - SyncLifecycleEvent.pending( - organization_id=ctx.organization.id, - source_connection_id=sc_id, - sync_job_id=sync_job_schema.id, - sync_id=sc_sync_id, - collection_id=collection_schema.id, - source_type=connection_schema.short_name, - collection_name=collection_schema.name, - collection_readable_id=collection_schema.readable_id, - ) - ) - - await self._temporal_workflow_service.run_source_connection_workflow( - sync=sync, - sync_job=sync_job, - collection=collection_schema, - connection=connection_schema, - ctx=ctx, - force_full_sync=force_full_sync, - ) - - return sync_job_schema.to_source_connection_job(sc_id) - - async def get_jobs( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - limit: int = 100, - ) -> List[SourceConnectionJob]: - """Get sync jobs for a source connection.""" - source_conn = await self._sc_repo.get(db, id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - return [] - - jobs = await self._sync_job_repo.get_all_by_sync_id(db, source_conn.sync_id, ctx) - return [self._response_builder.map_sync_job(j, source_conn.id) for j in jobs] - - async def cancel_job( - self, - db: AsyncSession, - *, - source_connection_id: UUID, - job_id: UUID, - ctx: ApiContext, - ) -> SourceConnectionJob: - """Cancel a running sync job. - - Sets CANCELLING, sends cancel to Temporal, and handles - edge cases (workflow not found, Temporal failure). - """ - source_conn = await self._sc_repo.get(db, source_connection_id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - raise HTTPException(status_code=400, detail="Source connection has no associated sync") - - sync_job = await self._sync_job_repo.get(db, job_id, ctx) - if not sync_job: - raise HTTPException(status_code=404, detail="Sync job not found") - if sync_job.sync_id != source_conn.sync_id: - raise HTTPException( - status_code=400, - detail="Sync job does not belong to this source connection", - ) - if sync_job.status not in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): - raise HTTPException( - status_code=400, - detail=f"Cannot cancel job in {sync_job.status} state", - ) - - await self._state_machine.transition( - sync_job_id=job_id, target=SyncJobStatus.CANCELLING, ctx=ctx - ) - - cancel_result = await self._temporal_workflow_service.cancel_sync_job_workflow( - str(job_id), ctx - ) - - if not cancel_result["success"]: - raise HTTPException( - status_code=502, detail="Failed to request cancellation from Temporal" - ) - - if not cancel_result["workflow_found"]: - ctx.logger.info(f"Workflow not found for job {job_id} - marking CANCELLED directly") - await self._state_machine.transition( - sync_job_id=job_id, - target=SyncJobStatus.CANCELLED, - ctx=ctx, - error="Workflow not found in Temporal - may have already completed", - ) - - await db.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - return sync_job_schema.to_source_connection_job(source_connection_id) - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - async def _cancel_active_syncs( - self, - db: AsyncSession, - sync_ids: List[UUID], - ctx: ApiContext, - ) -> List[UUID]: - """Cancel PENDING/RUNNING jobs and return IDs that need waiting.""" - non_terminal = {SyncJobStatus.PENDING, SyncJobStatus.RUNNING, SyncJobStatus.CANCELLING} - syncs_to_wait: List[UUID] = [] - for sync_id in sync_ids: - latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if not latest_job or latest_job.status not in non_terminal: - continue - if latest_job.status in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): - try: - await self._temporal_workflow_service.cancel_sync_job_workflow( - str(latest_job.id), ctx - ) - ctx.logger.info(f"Cancelled job {latest_job.id} before deletion") - except Exception as e: - ctx.logger.warning(f"Failed to cancel job {latest_job.id}: {e}") - syncs_to_wait.append(sync_id) - return syncs_to_wait - - async def _wait_for_terminal( - self, - db: AsyncSession, - syncs_to_wait: List[UUID], - timeout_seconds: int, - ctx: ApiContext, - ) -> None: - """Poll until all syncs reach a terminal state or timeout.""" - if not syncs_to_wait: - return - terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} - elapsed = 0.0 - remaining = list(syncs_to_wait) - while elapsed < timeout_seconds and remaining: - await asyncio.sleep(1.0) - elapsed += 1.0 - db.expire_all() - still_waiting = [] - for sid in remaining: - job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sid) - if job and job.status not in terminal: - still_waiting.append(sid) - remaining = still_waiting - if remaining: - ctx.logger.warning( - f"{len(remaining)} sync(s) did not reach terminal state " - f"within {timeout_seconds}s -- proceeding with deletion anyway" - ) - - async def _schedule_collection_cleanup( - self, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - ) -> None: - """Schedule a Temporal workflow for async Vespa/ARF cleanup.""" - if not sync_ids: - return - try: - await self._temporal_workflow_service.start_cleanup_sync_data_workflow( - sync_ids=[str(sid) for sid in sync_ids], - collection_id=str(collection_id), - organization_id=str(organization_id), - ctx=ctx, - ) - except Exception as e: - ctx.logger.error( - f"Failed to schedule async cleanup for collection {collection_id}: {e}. " - f"Data may be orphaned in Vespa/ARF." - ) - - async def _validate_force_full_sync( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext - ) -> None: - """Log force_full_sync intent. No-op if no cursor (already a full sync).""" - cursor = await self._sync_cursor_repo.get_by_sync_id(db, sync_id, ctx) - if not cursor or not cursor.cursor_data: - ctx.logger.info( - f"force_full_sync requested but no cursor data exists for sync {sync_id}. " - "This sync will perform a full sync by default." - ) - return - ctx.logger.info( - f"Force full sync requested for continuous sync {sync_id}. " - "Will ignore cursor data and perform full sync with orphaned entity cleanup." - ) - - async def _resolve_connection( - self, db: AsyncSession, source_conn, ctx: ApiContext - ) -> schemas.Connection: - """Resolve the Connection (not SourceConnection!) for a source connection.""" - if not source_conn.connection_id: - raise ValueError(f"Source connection {source_conn.id} has no connection_id") - conn = await self._connection_repo.get(db, source_conn.connection_id, ctx) - if not conn: - raise ValueError(f"Connection {source_conn.connection_id} not found") - return schemas.Connection.model_validate(conn, from_attributes=True) - - def _resolve_cron( - self, - schedule_config: Optional[ScheduleConfig], - source_entry: SourceRegistryEntry, - ctx: ApiContext, - ) -> Optional[str]: - """Resolve cron schedule from config or source defaults. - - When schedule_config is provided: - - cron is a string → use it - - cron is None → caller explicitly wants no schedule - When schedule_config is None → apply source-type defaults. - """ - if schedule_config is not None: - if schedule_config.cron is not None: - return schedule_config.cron - ctx.logger.info("Schedule cron explicitly null, no schedule") - return None - - if source_entry.supports_continuous: - ctx.logger.info("Continuous source, defaulting to 5-minute schedule") - return CONTINUOUS_SOURCE_DEFAULT_CRON - - now_utc = datetime.now(timezone.utc) - cron = DAILY_CRON_TEMPLATE.format(minute=now_utc.minute, hour=now_utc.hour) - ctx.logger.info(f"Defaulting to daily at {now_utc.hour:02d}:{now_utc.minute:02d} UTC") - return cron - - def _validate_cron_for_source( - self, - cron: str, - source_entry: SourceRegistryEntry, - ) -> None: - """Reject sub-hourly schedules for non-continuous sources.""" - if source_entry.supports_continuous: - return - - if cron == "* * * * *": - raise HTTPException( - status_code=400, - detail=( - f"Source '{source_entry.short_name}' does not support " - f"continuous syncs. Minimum interval is 1 hour." - ), - ) - - match = _SUB_HOURLY_PATTERN.match(cron) - if match and int(match.group(1)) < 60: - raise HTTPException( - status_code=400, - detail=( - f"Source '{source_entry.short_name}' does not support " - f"continuous syncs. Minimum interval is 1 hour." - ), - ) diff --git a/backend/airweave/domains/syncs/protocols.py b/backend/airweave/domains/syncs/protocols.py index eeda5e2bc..7d17b55fe 100644 --- a/backend/airweave/domains/syncs/protocols.py +++ b/backend/airweave/domains/syncs/protocols.py @@ -17,7 +17,7 @@ from airweave.domains.syncs.types import SyncProvisionResult, SyncTransitionResult from airweave.models.sync import Sync from airweave.models.sync_cursor import SyncCursor -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob +from airweave.schemas.source_connection import ScheduleConfig from airweave.schemas.sync import SyncCreate, SyncUpdate @@ -79,147 +79,148 @@ async def get_by_sync_id( ... -class SyncRecordServiceProtocol(Protocol): - """Sync record management: create syncs and trigger runs.""" +class SyncStateMachineProtocol(Protocol): + """Validated, idempotent sync status transitions with schedule side effects.""" - async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: - """Resolve destination connection IDs based on feature flags.""" + async def transition( + self, + sync_id: UUID, + target: SyncStatus, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Execute a validated, idempotent sync status transition. + + Side effects (schedule pause/unpause) run after the DB commit. + """ ... - async def create_sync( + +class SyncServiceProtocol(Protocol): + """Unified sync service — the public interface for the syncs domain. + + Provides lifecycle (create, get, pause, resume, delete), job management + (trigger_run, get_jobs, cancel_job), and execution (run) operations. + All methods speak the sync domain language; no source_connection types + cross this boundary. + """ + + # -- Lifecycle -- + + async def create( self, db: AsyncSession, *, name: str, source_connection_id: UUID, destination_connection_ids: List[UUID], - cron_schedule: Optional[str], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], run_immediately: bool, ctx: ApiContext, uow: UnitOfWork, - ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Create a Sync record and optionally a PENDING SyncJob. - - All writes happen inside the caller's UoW (no commit). - """ + ) -> SyncProvisionResult: + """Create sync + optional job + Temporal schedule atomically.""" ... - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Trigger a manual sync run. - - Returns (sync_schema, sync_job_schema). - Raises HTTPException if a job is already active. - """ + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + """Get a sync by ID.""" ... - -class SyncStateMachineProtocol(Protocol): - """Validated, idempotent sync status transitions with schedule side effects.""" - - async def transition( + async def pause( self, sync_id: UUID, - target: SyncStatus, ctx: BaseContext, *, reason: str = "", ) -> SyncTransitionResult: - """Execute a validated, idempotent sync status transition. - - Side effects (schedule pause/unpause) run after the DB commit. - """ + """Pause a sync.""" ... - -class SyncServiceProtocol(Protocol): - """Sync execution: build orchestrator and run.""" - - async def run( + async def resume( self, - sync: schemas.Sync, - sync_job: schemas.SyncJob, - collection: schemas.CollectionRecord, - source_connection: schemas.Connection, + sync_id: UUID, ctx: BaseContext, - force_full_sync: bool = False, - execution_config: Optional[SyncConfig] = None, - access_token: Optional[str] = None, - ) -> schemas.Sync: - """Run a sync via SyncFactory + SyncOrchestrator.""" + *, + reason: str = "", + ) -> SyncTransitionResult: + """Resume a paused sync.""" ... - -class SyncLifecycleServiceProtocol(Protocol): - """Sync lifecycle: provision, run, get jobs, cancel, teardown.""" - - async def teardown_syncs_for_collection( + async def delete( self, db: AsyncSession, *, - sync_ids: List[UUID], + sync_id: UUID, collection_id: UUID, organization_id: UUID, ctx: ApiContext, cancel_timeout_seconds: int = 15, ) -> None: - """Cancel running workflows and schedule async cleanup for a collection's syncs.""" + """Cancel active workflows and schedule async cleanup.""" ... - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Create sync + job + Temporal schedule atomically. + # -- Jobs -- - Returns None for federated search sources (no sync needed). - """ + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + """Resolve destination connection IDs (interim — will move to a registry).""" ... - async def run( + async def trigger_run( self, db: AsyncSession, *, - id: UUID, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, ctx: ApiContext, force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Trigger a sync run for a source connection.""" + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + """Create a PENDING job and start the Temporal workflow.""" ... async def get_jobs( self, db: AsyncSession, *, - id: UUID, + sync_id: UUID, ctx: ApiContext, limit: int = 100, - ) -> List[SourceConnectionJob]: - """Get sync jobs for a source connection.""" + ) -> List[schemas.SyncJob]: + """List jobs for a sync.""" ... async def cancel_job( self, db: AsyncSession, *, - source_connection_id: UUID, job_id: UUID, ctx: ApiContext, - ) -> SourceConnectionJob: - """Cancel a running sync job for a source connection.""" + ) -> schemas.SyncJob: + """Cancel a running sync job.""" + ... + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + """Validate and log force_full_sync intent.""" + ... + + # -- Execution -- + + async def run( + self, + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.CollectionRecord, + source_connection: schemas.Connection, + ctx: BaseContext, + force_full_sync: bool = False, + execution_config: Optional[SyncConfig] = None, + access_token: Optional[str] = None, + ) -> schemas.Sync: + """Run a sync via SyncFactory + SyncOrchestrator.""" ... diff --git a/backend/airweave/domains/syncs/record_service.py b/backend/airweave/domains/syncs/record_service.py deleted file mode 100644 index eb223b233..000000000 --- a/backend/airweave/domains/syncs/record_service.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Sync record service: create and trigger operations for Sync/SyncJob records.""" - -from typing import List, Optional, Tuple -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID -from airweave.core.shared_models import SyncJobStatus, SyncStatus -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.connections.protocols import ConnectionRepositoryProtocol -from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncRecordServiceProtocol, - SyncRepositoryProtocol, -) -from airweave.schemas.sync import SyncCreate -from airweave.schemas.sync_job import SyncJobCreate - - -class SyncRecordService(SyncRecordServiceProtocol): - """Create syncs, trigger sync runs, and list sync jobs via injected repositories.""" - - def __init__( - self, - sync_repo: SyncRepositoryProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - connection_repo: ConnectionRepositoryProtocol, - ) -> None: - """Initialize with injected repositories.""" - self._sync_repo = sync_repo - self._sync_job_repo = sync_job_repo - self._connection_repo = connection_repo - - async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: - """Resolve destination connection IDs.""" - return [NATIVE_VESPA_UUID] - - async def create_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - cron_schedule: Optional[str], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Create a Sync record and optionally a PENDING SyncJob. - - All writes happen inside the caller's UoW (no commit). - """ - sync_in = SyncCreate( - name=name, - source_connection_id=source_connection_id, - destination_connection_ids=destination_connection_ids, - cron_schedule=cron_schedule, - status=SyncStatus.ACTIVE, - run_immediately=run_immediately, - ) - - sync_schema = await self._sync_repo.create( - uow.session, - obj_in=sync_in, - ctx=ctx, - uow=uow, - ) - await uow.session.flush() - - sync_job_schema: Optional[schemas.SyncJob] = None - if run_immediately: - sync_job = await self._sync_job_repo.create( - uow.session, - SyncJobCreate(sync_id=sync_schema.id, status=SyncJobStatus.PENDING), - ctx, - uow=uow, - ) - await uow.session.flush() - await uow.session.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - return sync_schema, sync_job_schema - - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Trigger a manual sync run. - - Checks for existing active jobs, fetches the sync with - connections, creates a new SyncJob inside a UoW, and returns - both schemas. - - Raises: - HTTPException 400: if a job is already active. - ValueError: if the sync is not found. - """ - sync = await self._sync_repo.get(db, sync_id, ctx) - if not sync: - raise ValueError(f"Sync {sync_id} not found") - - if SyncStatus(sync.status) != SyncStatus.ACTIVE: - raise HTTPException( - status_code=409, - detail=f"Cannot trigger sync: sync is {sync.status}", - ) - - active_jobs = await self._sync_job_repo.get_active_for_sync(db, sync_id, ctx) - if active_jobs: - job_status = active_jobs[0].status.lower() - raise HTTPException( - status_code=400, - detail=f"Cannot start new sync: a sync job is already {job_status}", - ) - - sync_schema = schemas.Sync.model_validate(sync, from_attributes=True) - - async with UnitOfWork(db) as uow: - sync_job = await self._sync_job_repo.create( - uow.session, - schemas.SyncJobCreate( - sync_id=sync_id, - status=SyncJobStatus.PENDING, - ), - ctx, - uow=uow, - ) - await uow.commit() - await uow.session.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - return sync_schema, sync_job_schema diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index 2685ad624..024570f2d 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -1,40 +1,363 @@ -"""Sync execution service — runs a sync via SyncFactory + SyncOrchestrator. +"""Unified sync service — single transactional interface for the syncs domain. -Called exclusively from RunSyncActivity (Temporal worker). +Consolidates SyncRecordService, SyncLifecycleService, and the sync runner +into one service with clean, directed semantics. All callers interact through +this interface; internal implementation details (state machines, repos) are hidden. + +Methods speak the sync domain language: create, get, pause, resume, delete, +trigger_run, cancel_job, get_jobs, run. No source_connection types cross +this boundary. """ -from typing import Optional +import asyncio +import re +from datetime import datetime, timezone +from typing import List, Optional, Tuple +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas from airweave.api.context import ApiContext +from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID +from airweave.core.context import BaseContext from airweave.core.shared_models import SyncJobStatus, SyncStatus from airweave.db.session import get_db_context +from airweave.db.unit_of_work import UnitOfWork from airweave.domains.sources.exceptions.classifier import classify_error +from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol -from airweave.domains.syncs.jobs.protocols import SyncJobStateMachineProtocol +from airweave.domains.syncs.jobs.protocols import ( + SyncJobRepositoryProtocol, + SyncJobStateMachineProtocol, +) from airweave.domains.syncs.protocols import ( + SyncCursorRepositoryProtocol, + SyncRepositoryProtocol, SyncServiceProtocol, SyncStateMachineProtocol, ) +from airweave.domains.syncs.types import ( + CONTINUOUS_SOURCE_DEFAULT_CRON, + DAILY_CRON_TEMPLATE, + SyncProvisionResult, + SyncTransitionResult, +) +from airweave.domains.temporal.protocols import ( + TemporalScheduleServiceProtocol, + TemporalWorkflowServiceProtocol, +) +from airweave.schemas.source_connection import ScheduleConfig +from airweave.schemas.sync import SyncCreate +from airweave.schemas.sync_job import SyncJobCreate + +_SUB_HOURLY_PATTERN = re.compile(r"^\*/([1-5]?[0-9]) \* \* \* \*$") class SyncService(SyncServiceProtocol): - """Runs a sync via SyncFactory + SyncOrchestrator. + """Unified sync service — the public interface for the syncs domain. - Stateless — the only production caller is RunSyncActivity. + Callers use directed methods (create, pause, resume, delete) rather than + raw state transitions. The state machine is an internal implementation detail. """ - def __init__( + def __init__( # noqa: D107 self, - state_machine: SyncJobStateMachineProtocol, + sync_repo: SyncRepositoryProtocol, + sync_job_repo: SyncJobRepositoryProtocol, + sync_cursor_repo: SyncCursorRepositoryProtocol, + state_machine: SyncStateMachineProtocol, + job_state_machine: SyncJobStateMachineProtocol, + temporal_workflow_service: TemporalWorkflowServiceProtocol, + temporal_schedule_service: TemporalScheduleServiceProtocol, sync_factory: SyncFactoryProtocol, - sync_state_machine: SyncStateMachineProtocol, ) -> None: - """Initialize with state machine and factory dependencies.""" + self._sync_repo = sync_repo + self._sync_job_repo = sync_job_repo + self._sync_cursor_repo = sync_cursor_repo self._state_machine = state_machine + self._job_state_machine = job_state_machine + self._temporal_workflow_service = temporal_workflow_service + self._temporal_schedule_service = temporal_schedule_service self._sync_factory = sync_factory - self._sync_state_machine = sync_state_machine + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def create( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> SyncProvisionResult: + """Create sync + optional job + Temporal schedule atomically. + + Raises ValueError if called for federated sources or when there is + neither a schedule nor an immediate run request — callers must guard + these cases before calling create. + """ + if source_entry.federated_search: + raise ValueError(f"Cannot create sync for federated source '{source_entry.short_name}'") + + cron = self._resolve_cron(schedule_config, source_entry, ctx) + + if not cron and not run_immediately: + raise ValueError("Cannot create sync: no schedule and run_immediately=False") + + if cron: + self._validate_cron_for_source(cron, source_entry) + + sync_schema, sync_job_schema = await self._create_sync_records( + uow.session, + name=f"Sync for {name}", + source_connection_id=source_connection_id, + destination_connection_ids=destination_connection_ids, + cron_schedule=cron, + run_immediately=run_immediately, + ctx=ctx, + uow=uow, + ) + + if cron: + await self._temporal_schedule_service.create_or_update_schedule( + sync_id=sync_schema.id, + cron_schedule=cron, + db=uow.session, + ctx=ctx, + uow=uow, + collection_readable_id=collection_readable_id, + connection_id=source_connection_id, + ) + + return SyncProvisionResult( + sync_id=sync_schema.id, + sync=sync_schema, + sync_job=sync_job_schema, + cron_schedule=cron, + ) + + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + """Get a sync by ID.""" + sync = await self._sync_repo.get(db, sync_id, ctx) + if not sync: + raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found") + return sync + + async def pause( + self, + sync_id: UUID, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Pause a sync: update DB status, pause Temporal schedules.""" + return await self._state_machine.transition( + sync_id=sync_id, target=SyncStatus.PAUSED, ctx=ctx, reason=reason + ) + + async def resume( + self, + sync_id: UUID, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Resume a paused sync: update DB status, unpause Temporal schedules.""" + return await self._state_machine.transition( + sync_id=sync_id, target=SyncStatus.ACTIVE, ctx=ctx, reason=reason + ) + + async def delete( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + cancel_timeout_seconds: int = 15, + ) -> None: + """Cancel active workflows and schedule async cleanup for a single sync. + + 1. Cancels PENDING/RUNNING workflows via Temporal. + 2. Polls until terminal state (up to cancel_timeout_seconds). + 3. Schedules async cleanup workflow for Vespa/ARF/schedules. + + The caller is responsible for the CASCADE delete of DB records. + """ + needs_wait = await self._cancel_active_sync(db, sync_id, ctx) + if needs_wait: + await self._wait_for_terminal(db, sync_id, cancel_timeout_seconds, ctx) + await self._schedule_cleanup(sync_id, collection_id, organization_id, ctx) + + # ------------------------------------------------------------------ + # Jobs + # ------------------------------------------------------------------ + + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + """Resolve destination connection IDs.""" + return [NATIVE_VESPA_UUID] + + async def trigger_run( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: ApiContext, + force_full_sync: bool = False, + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + """Create a PENDING job and start the Temporal workflow. + + Validates the sync is ACTIVE and no active jobs exist, creates the + job record, then starts the Temporal workflow. + """ + sync = await self._sync_repo.get(db, sync_id, ctx) + if not sync: + raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found") + + if SyncStatus(sync.status) != SyncStatus.ACTIVE: + raise HTTPException( + status_code=409, + detail=f"Cannot trigger sync: sync is {sync.status}", + ) + + active_jobs = await self._sync_job_repo.get_active_for_sync(db, sync_id, ctx) + if active_jobs: + job_status = active_jobs[0].status.lower() + raise HTTPException( + status_code=400, + detail=f"Cannot start new sync: a sync job is already {job_status}", + ) + + sync_schema = schemas.Sync.model_validate(sync, from_attributes=True) + + sync_job = await self._sync_job_repo.create( + db, + SyncJobCreate(sync_id=sync_id, status=SyncJobStatus.PENDING), + ctx, + ) + await db.flush() + await db.refresh(sync_job) + sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + await self._temporal_workflow_service.run_source_connection_workflow( + sync=sync_schema, + sync_job=sync_job_schema, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + return sync_schema, sync_job_schema + + async def get_jobs( + self, + db: AsyncSession, + *, + sync_id: UUID, + ctx: ApiContext, + limit: int = 100, + ) -> List[schemas.SyncJob]: + """List jobs for a sync, most recent first.""" + jobs = await self._sync_job_repo.get_all_by_sync_id(db, sync_id, ctx, limit=limit) + return [schemas.SyncJob.model_validate(j, from_attributes=True) for j in jobs] + + async def cancel_job( + self, + db: AsyncSession, + *, + job_id: UUID, + ctx: ApiContext, + ) -> schemas.SyncJob: + """Cancel a running sync job. + + Transitions to CANCELLING, sends cancel to Temporal, and handles + edge cases (workflow not found, Temporal failure with one retry). + """ + sync_job = await self._sync_job_repo.get(db, job_id, ctx) + if not sync_job: + raise HTTPException(status_code=404, detail="Sync job not found") + + if sync_job.status not in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): + raise HTTPException( + status_code=400, + detail=f"Cannot cancel job in {sync_job.status} state", + ) + + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLING, ctx=ctx + ) + + cancel_result = await self._cancel_temporal_workflow_with_retry(job_id, ctx) + + if not cancel_result["success"]: + raise HTTPException( + status_code=502, detail="Failed to request cancellation from Temporal" + ) + + if not cancel_result["workflow_found"]: + # NOT_FOUND means the workflow already completed, never started, + # or was cleaned up. Check current DB state before marking cancelled. + await db.refresh(sync_job) + terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} + if sync_job.status not in terminal: + ctx.logger.info(f"Workflow not found for job {job_id}, marking CANCELLED") + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLED, ctx=ctx + ) + + await db.refresh(sync_job) + return schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + async def _cancel_temporal_workflow_with_retry( + self, job_id: UUID, ctx: ApiContext, max_retries: int = 1 + ) -> dict[str, bool]: + """Send cancellation to Temporal with a single retry on RPC failure.""" + for attempt in range(1 + max_retries): + result = await self._temporal_workflow_service.cancel_sync_job_workflow( + str(job_id), ctx + ) + if result["success"] or result["workflow_found"]: + return result + if attempt < max_retries: + await asyncio.sleep(0.5) + ctx.logger.info(f"Retrying cancel for job {job_id} (attempt {attempt + 2})") + return result + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + """Log force_full_sync intent. No-op if no cursor (already a full sync).""" + cursor = await self._sync_cursor_repo.get_by_sync_id(db, sync_id, ctx) + if not cursor or not cursor.cursor_data: + ctx.logger.info( + f"force_full_sync requested but no cursor data exists for sync {sync_id}. " + "This sync will perform a full sync by default." + ) + return + ctx.logger.info( + f"Force full sync requested for continuous sync {sync_id}. " + "Will ignore cursor data and perform full sync with orphaned entity cleanup." + ) + + # ------------------------------------------------------------------ + # Execution (Temporal activity entry point) + # ------------------------------------------------------------------ async def run( self, @@ -47,7 +370,10 @@ async def run( execution_config: Optional[SyncConfig] = None, access_token: Optional[str] = None, ) -> schemas.Sync: - """Run a sync.""" + """Run a sync via SyncFactory + SyncOrchestrator. + + Called exclusively from RunSyncActivity (Temporal worker). + """ try: async with get_db_context() as db: orchestrator = await self._sync_factory.create_orchestrator( @@ -66,7 +392,7 @@ async def run( classification = classify_error(e) - await self._state_machine.transition( + await self._job_state_machine.transition( sync_job_id=sync_job.id, target=SyncJobStatus.FAILED, ctx=ctx, @@ -76,7 +402,7 @@ async def run( if classification.category is not None and sync: try: - await self._sync_state_machine.transition( + await self._state_machine.transition( sync_id=sync.id, target=SyncStatus.PAUSED, ctx=ctx, @@ -88,3 +414,167 @@ async def run( raise e return await orchestrator.run() + + # ------------------------------------------------------------------ + # Private: record creation + # ------------------------------------------------------------------ + + async def _create_sync_records( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + cron_schedule: Optional[str], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: + """Create a Sync record and optionally a PENDING SyncJob. + + All writes happen inside the caller's UoW (no commit). + """ + sync_in = SyncCreate( + name=name, + source_connection_id=source_connection_id, + destination_connection_ids=destination_connection_ids, + cron_schedule=cron_schedule, + status=SyncStatus.ACTIVE, + run_immediately=run_immediately, + ) + + sync_schema = await self._sync_repo.create(uow.session, obj_in=sync_in, ctx=ctx, uow=uow) + await uow.session.flush() + + sync_job_schema: Optional[schemas.SyncJob] = None + if run_immediately: + sync_job = await self._sync_job_repo.create( + uow.session, + SyncJobCreate(sync_id=sync_schema.id, status=SyncJobStatus.PENDING), + ctx, + uow=uow, + ) + await uow.session.flush() + await uow.session.refresh(sync_job) + sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + return sync_schema, sync_job_schema + + # ------------------------------------------------------------------ + # Private: cron resolution + # ------------------------------------------------------------------ + + def _resolve_cron( + self, + schedule_config: Optional[ScheduleConfig], + source_entry: SourceRegistryEntry, + ctx: ApiContext, + ) -> Optional[str]: + """Resolve cron schedule from config or source defaults.""" + if schedule_config is not None: + if schedule_config.cron is not None: + return schedule_config.cron + ctx.logger.info("Schedule cron explicitly null, no schedule") + return None + + if source_entry.supports_continuous: + ctx.logger.info("Continuous source, defaulting to 5-minute schedule") + return CONTINUOUS_SOURCE_DEFAULT_CRON + + now_utc = datetime.now(timezone.utc) + cron = DAILY_CRON_TEMPLATE.format(minute=now_utc.minute, hour=now_utc.hour) + ctx.logger.info(f"Defaulting to daily at {now_utc.hour:02d}:{now_utc.minute:02d} UTC") + return cron + + def _validate_cron_for_source(self, cron: str, source_entry: SourceRegistryEntry) -> None: + """Reject sub-hourly schedules for non-continuous sources.""" + if source_entry.supports_continuous: + return + + if cron == "* * * * *": + raise HTTPException( + status_code=400, + detail=( + f"Source '{source_entry.short_name}' does not support " + f"continuous syncs. Minimum interval is 1 hour." + ), + ) + + match = _SUB_HOURLY_PATTERN.match(cron) + if match and int(match.group(1)) < 60: + raise HTTPException( + status_code=400, + detail=( + f"Source '{source_entry.short_name}' does not support " + f"continuous syncs. Minimum interval is 1 hour." + ), + ) + + # ------------------------------------------------------------------ + # Private: delete helpers + # ------------------------------------------------------------------ + + async def _cancel_active_sync( + self, + db: AsyncSession, + sync_id: UUID, + ctx: ApiContext, + ) -> bool: + """Cancel PENDING/RUNNING job for a sync. Returns True if it needs waiting.""" + non_terminal = {SyncJobStatus.PENDING, SyncJobStatus.RUNNING, SyncJobStatus.CANCELLING} + latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) + if not latest_job or latest_job.status not in non_terminal: + return False + if latest_job.status in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): + try: + await self._temporal_workflow_service.cancel_sync_job_workflow( + str(latest_job.id), ctx + ) + ctx.logger.info(f"Cancelled job {latest_job.id} before deletion") + except Exception as e: + ctx.logger.warning(f"Failed to cancel job {latest_job.id}: {e}") + return True + + async def _wait_for_terminal( + self, + db: AsyncSession, + sync_id: UUID, + timeout_seconds: int, + ctx: ApiContext, + ) -> None: + """Poll until the sync's latest job reaches a terminal state or timeout.""" + terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} + elapsed = 0.0 + while elapsed < timeout_seconds: + await asyncio.sleep(1.0) + elapsed += 1.0 + db.expire_all() + job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) + if not job or job.status in terminal: + return + ctx.logger.warning( + f"Sync {sync_id} did not reach terminal state " + f"within {timeout_seconds}s -- proceeding with deletion anyway" + ) + + async def _schedule_cleanup( + self, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + ) -> None: + """Schedule a Temporal workflow for async Vespa/ARF cleanup.""" + try: + await self._temporal_workflow_service.start_cleanup_sync_data_workflow( + sync_ids=[str(sync_id)], + collection_id=str(collection_id), + organization_id=str(organization_id), + ctx=ctx, + ) + except Exception as e: + ctx.logger.error( + f"Failed to schedule async cleanup for sync {sync_id}: {e}. " + f"Data may be orphaned in Vespa/ARF." + ) diff --git a/backend/airweave/domains/syncs/tests/test_lifecycle_service.py b/backend/airweave/domains/syncs/tests/test_lifecycle_service.py deleted file mode 100644 index 7a926292c..000000000 --- a/backend/airweave/domains/syncs/tests/test_lifecycle_service.py +++ /dev/null @@ -1,876 +0,0 @@ -"""Table-driven unit tests for SyncLifecycleService. - -Covers provision_sync(), run(), get_jobs(), and cancel_job() with -happy paths and error edge cases. -""" - -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Optional -from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID, uuid4 - -import pytest -from fastapi import HTTPException - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncJobStatus, SyncStatus -from airweave.domains.collections.fakes.repository import FakeCollectionRepository -from airweave.domains.connections.fakes.repository import FakeConnectionRepository -from airweave.domains.source_connections.fakes.repository import ( - FakeSourceConnectionRepository, -) -from airweave.domains.source_connections.fakes.response import FakeResponseBuilder -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.fakes.cursor_repository import FakeSyncCursorRepository -from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService -from airweave.domains.syncs.lifecycle_service import SyncLifecycleService -from airweave.domains.syncs.types import CONTINUOUS_SOURCE_DEFAULT_CRON, SyncProvisionResult -from airweave.domains.temporal.fakes.schedule_service import FakeTemporalScheduleService -from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService -from airweave.models.collection import Collection # spec only -from airweave.models.connection import Connection # spec only -from airweave.models.source_connection import SourceConnection # spec only -from airweave.models.sync_cursor import SyncCursor # spec only -from airweave.models.sync_job import SyncJob # spec only -from airweave.platform.configs._base import Fields -from airweave.schemas.organization import Organization -from airweave.schemas.source_connection import ScheduleConfig - -NOW = datetime.now(timezone.utc) -ORG_ID = uuid4() -SC_ID = uuid4() -SYNC_ID = uuid4() -JOB_ID = uuid4() -COLLECTION_ID = uuid4() -CONNECTION_ID = uuid4() -DEST_CONN_ID = uuid4() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _ctx() -> ApiContext: - org = Organization(id=str(ORG_ID), name="Test Org", created_at=NOW, modified_at=NOW) - return ApiContext( - request_id="test-req", - organization=org, - auth_method=AuthMethod.SYSTEM, - logger=logger.with_context(request_id="test-req"), - ) - - -def _source_connection( - id: UUID = SC_ID, - sync_id: Optional[UUID] = SYNC_ID, - connection_id: UUID = CONNECTION_ID, -) -> MagicMock: - sc = MagicMock(spec=SourceConnection) - sc.id = id - sc.sync_id = sync_id - sc.connection_id = connection_id - sc.readable_collection_id = "test-collection" - sc.organization_id = ORG_ID - return sc - - -def _collection() -> MagicMock: - col = MagicMock(spec=Collection) - col.id = COLLECTION_ID - col.name = "Test Collection" - col.readable_id = "test-collection" - col.organization_id = ORG_ID - col.vector_db_deployment_metadata_id = uuid4() - col.sync_config = None - col.created_by_email = None - col.modified_by_email = None - return col - - -def _connection() -> MagicMock: - conn = MagicMock(spec=Connection) - conn.id = CONNECTION_ID - conn.name = "Test Connection" - conn.readable_id = "test-connection-abc123" - conn.description = None - conn.short_name = "github" - conn.integration_type = "source" - conn.integration_credential_id = None - conn.status = "active" - conn.organization_id = ORG_ID - conn.created_at = NOW - conn.modified_at = NOW - conn.created_by_email = None - conn.modified_by_email = None - return conn - - -def _sync_job( - id: UUID = JOB_ID, - sync_id: UUID = SYNC_ID, - status: SyncJobStatus = SyncJobStatus.PENDING, -) -> MagicMock: - job = MagicMock(spec=SyncJob) - job.id = id - job.sync_id = sync_id - job.status = status - job.organization_id = ORG_ID - job.created_at = NOW - job.modified_at = NOW - job.created_by_email = "test@example.com" - job.modified_by_email = "test@example.com" - return job - - -class _FakeEventBus: - """Minimal fake for EventBus.publish.""" - - def __init__(self) -> None: - self.events: list = [] - - async def publish(self, event) -> None: - self.events.append(event) - - -def _build_service( - sc_repo=None, - collection_repo=None, - connection_repo=None, - sync_cursor_repo=None, - sync_service=None, - state_machine=None, - sync_job_repo=None, - temporal_workflow_service=None, - temporal_schedule_service=None, - response_builder=None, - event_bus=None, -) -> SyncLifecycleService: - return SyncLifecycleService( - sc_repo=sc_repo or FakeSourceConnectionRepository(), - collection_repo=collection_repo or FakeCollectionRepository(), - connection_repo=connection_repo or FakeConnectionRepository(), - sync_cursor_repo=sync_cursor_repo or FakeSyncCursorRepository(), - sync_service=sync_service or FakeSyncRecordService(), - state_machine=state_machine or AsyncMock(), - sync_job_repo=sync_job_repo or FakeSyncJobRepository(), - temporal_workflow_service=temporal_workflow_service or FakeTemporalWorkflowService(), - temporal_schedule_service=temporal_schedule_service or FakeTemporalScheduleService(), - response_builder=response_builder or FakeResponseBuilder(), - event_bus=event_bus or _FakeEventBus(), - ) - - -# --------------------------------------------------------------------------- -# Source entry helper -# --------------------------------------------------------------------------- - - -def _source_entry( - short_name: str = "github", - supports_continuous: bool = False, - federated_search: bool = False, -) -> SourceRegistryEntry: - """Create a minimal SourceRegistryEntry for testing.""" - empty_fields = Fields(fields=[]) - return SourceRegistryEntry( - name="Test Source", - short_name=short_name, - description="Test source for unit tests", - class_name="FakeSource", - source_class_ref=type("FakeSource", (), {}), - config_ref=None, - auth_config_ref=None, - auth_fields=empty_fields, - config_fields=empty_fields, - supported_auth_providers=[], - runtime_auth_all_fields=[], - runtime_auth_optional_fields=set(), - auth_methods=None, - oauth_type=None, - requires_byoc=False, - supports_continuous=supports_continuous, - supports_cursor=False, - federated_search=federated_search, - supports_temporal_relevance=False, - supports_access_control=False, - rate_limit_level=None, - feature_flag=None, - labels=None, - output_entity_definitions=[], - ) - - -def _sync_schema(id: UUID = SYNC_ID) -> schemas.Sync: - """Create a minimal Sync schema for testing.""" - return schemas.Sync( - id=id, - name="Sync for Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - status=SyncStatus.ACTIVE, - organization_id=ORG_ID, - created_at=NOW, - modified_at=NOW, - ) - - -def _sync_job_schema(id: UUID = JOB_ID, sync_id: UUID = SYNC_ID) -> schemas.SyncJob: - """Create a minimal SyncJob schema for testing.""" - return schemas.SyncJob( - id=id, - sync_id=sync_id, - status=SyncJobStatus.PENDING, - organization_id=ORG_ID, - created_at=NOW, - modified_at=NOW, - ) - - -class _FakeUoW: - """Minimal UoW fake — just exposes .session.""" - - def __init__(self, session=None): - self.session = session or AsyncMock() - - -# --------------------------------------------------------------------------- -# run() tests -# --------------------------------------------------------------------------- - - -@dataclass -class RunCase: - """Table-driven case for run().""" - - name: str - sc: Optional[SourceConnection] = None - collection: Optional[Collection] = None - connection: Optional[Connection] = None - force_full_sync: bool = False - cursor: Optional[SyncCursor] = None - trigger_result: Optional[tuple] = None - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -_SYNC_SCHEMA = MagicMock() -_SYNC_SCHEMA.model_dump.return_value = {} -_SYNC_JOB_SCHEMA = MagicMock() -_SYNC_JOB_SCHEMA.id = JOB_ID -_SYNC_JOB_SCHEMA.to_source_connection_job.return_value = MagicMock() - - -RUN_CASES = [ - RunCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - RunCase( - name="no_sync_id", - sc=_source_connection(sync_id=None), - expected_error="Source connection has no associated sync", - expected_status=400, - ), - # force_full_sync_no_cursor: removed — service now logs info and proceeds - # (no cursor means first sync, which is inherently a full sync) -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", RUN_CASES, ids=lambda c: c.name) -async def test_run_errors(case: RunCase): - """Test run() error paths.""" - sc_repo = FakeSourceConnectionRepository() - collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_cursor_repo = FakeSyncCursorRepository() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.collection: - collection_repo.seed(case.collection.id, case.collection) - collection_repo.seed_readable(case.collection.readable_id, case.collection) - if case.connection: - connection_repo.seed(case.connection.id, case.connection) - if case.cursor: - sync_cursor_repo.seed(SYNC_ID, case.cursor) - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_cursor_repo=sync_cursor_repo, - ) - - with pytest.raises(HTTPException) as exc_info: - await svc.run( - AsyncMock(), - id=case.sc.id if case.sc else uuid4(), - ctx=_ctx(), - force_full_sync=case.force_full_sync, - ) - - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_run_force_full_sync_happy_path(): - """Test run() with force_full_sync=True and valid cursor data.""" - sc_repo = FakeSourceConnectionRepository() - collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_cursor_repo = FakeSyncCursorRepository() - sync_service = FakeSyncRecordService() - temporal_workflow_service = FakeTemporalWorkflowService() - event_bus = _FakeEventBus() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - collection_repo.seed_readable("test-collection", _collection()) - connection_repo.seed(CONNECTION_ID, _connection()) - - cursor = MagicMock(spec=SyncCursor) - cursor.cursor_data = {"last_modified": "2024-01-01"} - sync_cursor_repo.seed(SYNC_ID, cursor) - - mock_sync = MagicMock() - mock_sync_job = MagicMock() - sync_service.set_trigger_result(mock_sync, mock_sync_job) - - mock_collection_schema = MagicMock() - mock_collection_schema.id = COLLECTION_ID - mock_collection_schema.name = "Test Collection" - mock_collection_schema.readable_id = "test-collection" - - mock_connection_schema = MagicMock() - mock_connection_schema.short_name = "github" - - expected_sc_job = MagicMock() - mock_sj_schema = MagicMock() - mock_sj_schema.id = JOB_ID - mock_sj_schema.to_source_connection_job.return_value = expected_sc_job - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_cursor_repo=sync_cursor_repo, - sync_service=sync_service, - temporal_workflow_service=temporal_workflow_service, - event_bus=event_bus, - ) - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with ( - patch(f"{_mod}.Collection.model_validate", return_value=mock_collection_schema), - patch(f"{_mod}.Connection.model_validate", return_value=mock_connection_schema), - patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema), - ): - result = await svc.run(AsyncMock(), id=SC_ID, ctx=_ctx(), force_full_sync=True) - - assert result == expected_sc_job - assert len(event_bus.events) == 1 - assert len(temporal_workflow_service._calls) == 1 - wf_call = temporal_workflow_service._calls[0] - assert wf_call[0] == "run_source_connection_workflow" - assert wf_call[7] is True # force_full_sync arg - - -@pytest.mark.asyncio -async def test_run_happy_path(): - """Test run() happy path: triggers workflow and publishes event.""" - sc_repo = FakeSourceConnectionRepository() - collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_service = FakeSyncRecordService() - temporal_workflow_service = FakeTemporalWorkflowService() - event_bus = _FakeEventBus() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - collection_repo.seed_readable("test-collection", _collection()) - connection_repo.seed(CONNECTION_ID, _connection()) - - mock_sync = MagicMock() - mock_sync_job = MagicMock() - sync_service.set_trigger_result(mock_sync, mock_sync_job) - - mock_collection_schema = MagicMock() - mock_collection_schema.id = COLLECTION_ID - mock_collection_schema.name = "Test Collection" - mock_collection_schema.readable_id = "test-collection" - - mock_connection_schema = MagicMock() - mock_connection_schema.short_name = "github" - - expected_sc_job = MagicMock() - mock_sj_schema = MagicMock() - mock_sj_schema.id = JOB_ID - mock_sj_schema.to_source_connection_job.return_value = expected_sc_job - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_service=sync_service, - temporal_workflow_service=temporal_workflow_service, - event_bus=event_bus, - ) - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with ( - patch(f"{_mod}.Collection.model_validate", return_value=mock_collection_schema), - patch(f"{_mod}.Connection.model_validate", return_value=mock_connection_schema), - patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema), - ): - result = await svc.run(AsyncMock(), id=SC_ID, ctx=_ctx()) - - assert result == expected_sc_job - assert len(event_bus.events) == 1 - assert len(temporal_workflow_service._calls) == 1 - assert temporal_workflow_service._calls[0][0] == "run_source_connection_workflow" - - -# --------------------------------------------------------------------------- -# get_jobs() tests -# --------------------------------------------------------------------------- - - -@dataclass -class GetJobsCase: - """Table-driven case for get_jobs().""" - - name: str - sc: Optional[SourceConnection] = None - jobs: list = field(default_factory=list) - expected_count: int = 0 - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -GET_JOBS_CASES = [ - GetJobsCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - GetJobsCase( - name="no_sync_id_returns_empty", - sc=_source_connection(sync_id=None), - expected_count=0, - ), - GetJobsCase( - name="empty_jobs", - sc=_source_connection(), - expected_count=0, - ), - GetJobsCase( - name="with_seeded_jobs", - sc=_source_connection(), - jobs=[_sync_job(id=uuid4()), _sync_job(id=uuid4())], - expected_count=2, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", GET_JOBS_CASES, ids=lambda c: c.name) -async def test_get_jobs(case: GetJobsCase): - """Test get_jobs() with table-driven cases.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.jobs and case.sc and case.sc.sync_id: - sync_job_repo.seed_jobs_for_sync(case.sc.sync_id, case.jobs) - - svc = _build_service(sc_repo=sc_repo, sync_job_repo=sync_job_repo) - - if case.expected_error: - with pytest.raises(HTTPException) as exc_info: - await svc.get_jobs(AsyncMock(), id=case.sc.id if case.sc else uuid4(), ctx=_ctx()) - assert exc_info.value.status_code == case.expected_status - else: - result = await svc.get_jobs(AsyncMock(), id=case.sc.id, ctx=_ctx()) - assert len(result) == case.expected_count - - -# --------------------------------------------------------------------------- -# cancel_job() tests -# --------------------------------------------------------------------------- - - -@dataclass -class CancelCase: - """Table-driven case for cancel_job().""" - - name: str - sc: Optional[SourceConnection] = None - job: Optional[SyncJob] = None - cancel_success: bool = True - workflow_found: bool = True - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -CANCEL_CASES = [ - CancelCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - CancelCase( - name="no_sync_id", - sc=_source_connection(sync_id=None), - expected_error="Source connection has no associated sync", - expected_status=400, - ), - CancelCase( - name="job_not_found", - sc=_source_connection(), - expected_error="Sync job not found", - expected_status=404, - ), - CancelCase( - name="wrong_sync", - sc=_source_connection(), - job=_sync_job(sync_id=uuid4()), - expected_error="Sync job does not belong to this source connection", - expected_status=400, - ), - CancelCase( - name="non_cancellable_state", - sc=_source_connection(), - job=_sync_job(status=SyncJobStatus.COMPLETED), - expected_error="Cannot cancel job in", - expected_status=400, - ), - CancelCase( - name="temporal_failure", - sc=_source_connection(), - job=_sync_job(status=SyncJobStatus.RUNNING), - cancel_success=False, - expected_error="Failed to request cancellation from Temporal", - expected_status=502, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", CANCEL_CASES, ids=lambda c: c.name) -async def test_cancel_job_errors(case: CancelCase): - """Test cancel_job() error paths.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.job: - sync_job_repo.seed(case.job.id, case.job) - - temporal_workflow_service.set_cancel_result( - {"success": case.cancel_success, "workflow_found": case.workflow_found} - ) - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - with pytest.raises(HTTPException) as exc_info: - await svc.cancel_job( - AsyncMock(), - source_connection_id=case.sc.id if case.sc else uuid4(), - job_id=case.job.id if case.job else JOB_ID, - ctx=_ctx(), - ) - - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_cancel_job_happy_path(): - """Successful cancel: workflow found, job transitions to CANCELLING.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - - job = _sync_job(status=SyncJobStatus.RUNNING) - sync_job_repo.seed(JOB_ID, job) - - temporal_workflow_service.set_cancel_result({"success": True, "workflow_found": True}) - - db_mock = AsyncMock() - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - mock_sj_schema = MagicMock() - expected_result = MagicMock() - mock_sj_schema.to_source_connection_job.return_value = expected_result - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema): - result = await svc.cancel_job( - db_mock, source_connection_id=SC_ID, job_id=JOB_ID, ctx=_ctx() - ) - - assert result == expected_result - assert state_machine.transition.await_count == 1 - call_kwargs = state_machine.transition.call_args.kwargs - assert call_kwargs["target"] == SyncJobStatus.CANCELLING - assert len(temporal_workflow_service._calls) == 1 - - -@pytest.mark.asyncio -async def test_cancel_job_workflow_not_found(): - """When workflow is not found, job should be marked CANCELLED directly.""" - from unittest.mock import patch - - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - - job = _sync_job(status=SyncJobStatus.RUNNING) - sync_job_repo.seed(JOB_ID, job) - - temporal_workflow_service.set_cancel_result({"success": True, "workflow_found": False}) - - db_mock = AsyncMock() - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - mock_sj_schema = MagicMock() - mock_sj_schema.to_source_connection_job.return_value = MagicMock() - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema): - await svc.cancel_job(db_mock, source_connection_id=SC_ID, job_id=JOB_ID, ctx=_ctx()) - - assert state_machine.transition.await_count == 2 - targets = [c.kwargs["target"] for c in state_machine.transition.call_args_list] - assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] - - -# --------------------------------------------------------------------------- -# provision_sync() tests -# --------------------------------------------------------------------------- - - -@dataclass -class ProvisionCase: - """Table-driven case for provision_sync().""" - - name: str - source_entry: Optional[SourceRegistryEntry] = None - schedule_config: Optional[ScheduleConfig] = None - run_immediately: bool = True - expected_none: bool = False - expected_error: Optional[str] = None - expected_status: Optional[int] = None - expected_cron: Optional[str] = None - expect_schedule_call: bool = False - - -PROVISION_CASES = [ - ProvisionCase( - name="federated_search_returns_none", - source_entry=_source_entry(federated_search=True), - expected_none=True, - ), - ProvisionCase( - name="no_schedule_no_immediate_returns_none", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron=None), - run_immediately=False, - expected_none=True, - ), - ProvisionCase( - name="default_continuous_schedule", - source_entry=_source_entry(supports_continuous=True), - expected_cron=CONTINUOUS_SOURCE_DEFAULT_CRON, - expect_schedule_call=True, - ), - ProvisionCase( - name="explicit_cron_used", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron="0 3 * * *"), - expected_cron="0 3 * * *", - expect_schedule_call=True, - ), - ProvisionCase( - name="sub_hourly_rejected_for_non_continuous", - source_entry=_source_entry(supports_continuous=False), - schedule_config=ScheduleConfig(cron="*/5 * * * *"), - expected_error="does not support continuous syncs", - expected_status=400, - ), - ProvisionCase( - name="every_minute_rejected_for_non_continuous", - source_entry=_source_entry(supports_continuous=False), - schedule_config=ScheduleConfig(cron="* * * * *"), - expected_error="does not support continuous syncs", - expected_status=400, - ), - ProvisionCase( - name="sub_hourly_ok_for_continuous", - source_entry=_source_entry(supports_continuous=True), - schedule_config=ScheduleConfig(cron="*/5 * * * *"), - expected_cron="*/5 * * * *", - expect_schedule_call=True, - ), - ProvisionCase( - name="happy_path_immediate_no_schedule", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron=None), - run_immediately=True, - expected_none=False, - expected_cron=None, - expect_schedule_call=False, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", PROVISION_CASES, ids=lambda c: c.name) -async def test_provision_sync(case: ProvisionCase): - """Test provision_sync() with table-driven cases.""" - sync_service = FakeSyncRecordService() - temporal_schedule_service = FakeTemporalScheduleService() - - mock_sync = _sync_schema() - mock_sync_job = _sync_job_schema() - sync_service.set_create_result(mock_sync, mock_sync_job) - - svc = _build_service( - sync_service=sync_service, - temporal_schedule_service=temporal_schedule_service, - ) - - uow = _FakeUoW() - - if case.expected_error: - with pytest.raises(HTTPException) as exc_info: - await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=case.source_entry or _source_entry(), - schedule_config=case.schedule_config, - run_immediately=case.run_immediately, - ctx=_ctx(), - uow=uow, - ) - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - return - - result = await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=case.source_entry or _source_entry(), - schedule_config=case.schedule_config, - run_immediately=case.run_immediately, - ctx=_ctx(), - uow=uow, - ) - - if case.expected_none: - assert result is None - assert len(sync_service._calls) == 0 - return - - assert result is not None - assert isinstance(result, SyncProvisionResult) - assert result.sync_id == SYNC_ID - assert result.cron_schedule == case.expected_cron - - create_calls = [c for c in sync_service._calls if c[0] == "create_sync"] - assert len(create_calls) == 1 - assert create_calls[0][3] == case.expected_cron # cron_schedule arg - - schedule_calls = [ - c for c in temporal_schedule_service._calls if c[0] == "create_or_update_schedule" - ] - if case.expect_schedule_call: - assert len(schedule_calls) == 1 - assert schedule_calls[0][2] == case.expected_cron - else: - assert len(schedule_calls) == 0 - - -@pytest.mark.asyncio -async def test_provision_sync_default_daily_schedule(): - """Default daily schedule uses current UTC hour:minute.""" - sync_service = FakeSyncRecordService() - temporal_schedule_service = FakeTemporalScheduleService() - - mock_sync = _sync_schema() - sync_service.set_create_result(mock_sync, _sync_job_schema()) - - svc = _build_service( - sync_service=sync_service, - temporal_schedule_service=temporal_schedule_service, - ) - - uow = _FakeUoW() - result = await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=_source_entry(supports_continuous=False), - schedule_config=None, - run_immediately=True, - ctx=_ctx(), - uow=uow, - ) - - assert result is not None - parts = result.cron_schedule.split() - assert len(parts) == 5 - assert parts[2:] == ["*", "*", "*"] # daily schedule pattern - - schedule_calls = [ - c for c in temporal_schedule_service._calls if c[0] == "create_or_update_schedule" - ] - assert len(schedule_calls) == 1 diff --git a/backend/airweave/domains/syncs/tests/test_record_service.py b/backend/airweave/domains/syncs/tests/test_record_service.py deleted file mode 100644 index 83dfa4fdd..000000000 --- a/backend/airweave/domains/syncs/tests/test_record_service.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Table-driven tests for SyncRecordService. - -Covers trigger_sync_run (happy, active-job, not-found). -""" - -from dataclasses import dataclass -from typing import Optional -from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID, uuid4 - -import pytest -from fastapi import HTTPException - -from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID -from airweave.domains.syncs.record_service import SyncRecordService - -ORG_ID = uuid4() -SYNC_ID = uuid4() - - -def _mock_ctx() -> MagicMock: - ctx = MagicMock() - ctx.organization = MagicMock() - ctx.organization.id = ORG_ID - ctx.logger = MagicMock() - ctx.has_feature = MagicMock(return_value=False) - return ctx - - -def _mock_sync_model(sync_id: UUID = SYNC_ID, status: str = "active") -> MagicMock: - sync = MagicMock() - sync.id = sync_id - sync.name = "test-sync" - sync.status = status - return sync - - -def _mock_sync_job_model(sync_id: UUID = SYNC_ID, status: str = "PENDING") -> MagicMock: - job = MagicMock() - job.id = uuid4() - job.sync_id = sync_id - job.status = status - job.organization_id = ORG_ID - return job - - -# --------------------------------------------------------------------------- -# trigger_sync_run -# --------------------------------------------------------------------------- - - -@dataclass -class TriggerCase: - """Parameters for a single trigger_sync_run scenario.""" - - name: str - active_jobs: list - sync_exists: bool = True - sync_status: str = "active" - expect_error: Optional[type] = None - error_status: Optional[int] = None - - -TRIGGER_CASES = [ - TriggerCase( - name="happy_path", - active_jobs=[], - sync_exists=True, - ), - TriggerCase( - name="active_job_blocks", - active_jobs=[_mock_sync_job_model(status="running")], - expect_error=HTTPException, - error_status=400, - ), - TriggerCase( - name="sync_not_found", - active_jobs=[], - sync_exists=False, - expect_error=ValueError, - ), - TriggerCase( - name="non_active_sync_rejected", - active_jobs=[], - sync_exists=True, - sync_status="paused", - expect_error=HTTPException, - error_status=409, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", TRIGGER_CASES, ids=lambda c: c.name) -async def test_trigger_sync_run(case: TriggerCase) -> None: - """Verify trigger_sync_run behaviour for each scenario.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - - sync_job_repo.get_active_for_sync = AsyncMock(return_value=case.active_jobs) - sync_repo.get = AsyncMock( - return_value=_mock_sync_model(status=case.sync_status) if case.sync_exists else None - ) - - created_job = _mock_sync_job_model() - sync_job_repo.create = AsyncMock(return_value=created_job) - - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - db = AsyncMock() - ctx = _mock_ctx() - - if case.expect_error: - with pytest.raises(case.expect_error) as exc_info: - with patch("airweave.domains.syncs.record_service.UnitOfWork") as mock_uow_cls: - mock_uow = AsyncMock() - mock_uow.session = AsyncMock() - mock_uow.commit = AsyncMock() - mock_uow.session.refresh = AsyncMock() - mock_uow_cls.return_value.__aenter__ = AsyncMock(return_value=mock_uow) - mock_uow_cls.return_value.__aexit__ = AsyncMock(return_value=False) - await svc.trigger_sync_run(db, SYNC_ID, ctx) - if case.error_status and isinstance(exc_info.value, HTTPException): - assert exc_info.value.status_code == case.error_status - else: - with patch("airweave.domains.syncs.record_service.UnitOfWork") as mock_uow_cls: - mock_uow = AsyncMock() - mock_uow.session = AsyncMock() - mock_uow.commit = AsyncMock() - mock_uow.session.refresh = AsyncMock() - mock_uow_cls.return_value.__aenter__ = AsyncMock(return_value=mock_uow) - mock_uow_cls.return_value.__aexit__ = AsyncMock(return_value=False) - - with patch("airweave.domains.syncs.record_service.schemas") as mock_schemas: - mock_sync_schema = MagicMock() - mock_job_schema = MagicMock() - mock_schemas.Sync.model_validate.return_value = mock_sync_schema - mock_schemas.SyncJob.model_validate.return_value = mock_job_schema - mock_schemas.SyncJobCreate = MagicMock() - - result = await svc.trigger_sync_run(db, SYNC_ID, ctx) - assert result == (mock_sync_schema, mock_job_schema) - - sync_job_repo.create.assert_called_once() - mock_uow.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# resolve_destination_ids -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_create_sync_flushes_and_refreshes_job_before_validation() -> None: - """Verify create_sync flushes twice and refreshes the job before validation.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - - sync_schema = MagicMock(id=SYNC_ID) - sync_repo.create = AsyncMock(return_value=sync_schema) - created_job = MagicMock() - sync_job_repo.create = AsyncMock(return_value=created_job) - - uow = MagicMock() - uow.session = AsyncMock() - uow.session.flush = AsyncMock() - uow.session.refresh = AsyncMock() - ctx = _mock_ctx() - - with patch("airweave.domains.syncs.record_service.schemas") as mock_schemas: - validated_job_schema = MagicMock() - mock_schemas.SyncJob.model_validate.return_value = validated_job_schema - mock_schemas.SyncJobCreate = MagicMock() - - sync, sync_job = await svc.create_sync( - AsyncMock(), - name="Test Sync", - source_connection_id=uuid4(), - destination_connection_ids=[NATIVE_VESPA_UUID], - cron_schedule=None, - run_immediately=True, - ctx=ctx, - uow=uow, - ) - - assert sync is sync_schema - assert sync_job is validated_job_schema - assert uow.session.flush.await_count == 2 - uow.session.refresh.assert_awaited_once_with(created_job) - - -@pytest.mark.asyncio -async def test_create_sync_flushes_sync_even_without_immediate_job() -> None: - """Verify create_sync flushes once and skips job when run_immediately=False.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - - sync_schema = MagicMock(id=SYNC_ID) - sync_repo.create = AsyncMock(return_value=sync_schema) - - uow = MagicMock() - uow.session = AsyncMock() - uow.session.flush = AsyncMock() - uow.session.refresh = AsyncMock() - ctx = _mock_ctx() - - sync, sync_job = await svc.create_sync( - AsyncMock(), - name="Test Sync", - source_connection_id=uuid4(), - destination_connection_ids=[NATIVE_VESPA_UUID], - cron_schedule="0 * * * *", - run_immediately=False, - ctx=ctx, - uow=uow, - ) - - assert sync is sync_schema - assert sync_job is None - uow.session.flush.assert_awaited_once() - uow.session.refresh.assert_not_awaited() - sync_job_repo.create.assert_not_called() - - -@pytest.mark.asyncio -async def test_resolve_destination_ids_returns_native_only() -> None: - """Verify resolve_destination_ids returns only native Vespa UUID.""" - svc = SyncRecordService( - sync_repo=AsyncMock(), - sync_job_repo=AsyncMock(), - connection_repo=AsyncMock(), - ) - db = AsyncMock() - ctx = _mock_ctx() - - destination_ids = await svc.resolve_destination_ids(db, ctx) - - assert destination_ids == [NATIVE_VESPA_UUID] diff --git a/backend/airweave/domains/syncs/tests/test_service.py b/backend/airweave/domains/syncs/tests/test_service.py index 601aa62ef..c55f680f9 100644 --- a/backend/airweave/domains/syncs/tests/test_service.py +++ b/backend/airweave/domains/syncs/tests/test_service.py @@ -10,6 +10,7 @@ from uuid import uuid4 import pytest +from fastapi import HTTPException from airweave.core.shared_models import SyncJobStatus from airweave.domains.syncs.service import SyncService @@ -93,9 +94,14 @@ async def test_run(case: RunCase): ) svc = SyncService( - state_machine=fake_state_machine, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=MagicMock(), + job_state_machine=fake_state_machine, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) sync = _mock_sync() @@ -160,9 +166,14 @@ async def test_run_forwards_optional_kwargs(): ) svc = SyncService( - state_machine=fake_state_machine, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=MagicMock(), + job_state_machine=fake_state_machine, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) mock_db = AsyncMock() @@ -205,9 +216,7 @@ async def test_credential_error_propagates_error_category(): cause = TokenExpiredError( "JWT expired", source_short_name="github", provider_kind=AuthProviderKind.OAUTH ) - wrapper = SourceValidationError( - short_name="github", reason="credential validation failed" - ) + wrapper = SourceValidationError(short_name="github", reason="credential validation failed") wrapper.__cause__ = cause fake_sm = AsyncMock() @@ -215,9 +224,14 @@ async def test_credential_error_propagates_error_category(): fake_factory.create_orchestrator = AsyncMock(side_effect=wrapper) svc = SyncService( - state_machine=fake_sm, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=AsyncMock(), + job_state_machine=fake_sm, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) with patch("airweave.domains.syncs.service.get_db_context") as mock_db_ctx: @@ -236,10 +250,7 @@ async def test_credential_error_propagates_error_category(): fake_sm.transition.assert_awaited_once() call_kwargs = fake_sm.transition.call_args.kwargs assert call_kwargs["target"] == SyncJobStatus.FAILED - assert ( - call_kwargs["error_category"] - == SourceConnectionErrorCategory.OAUTH_CREDENTIALS_EXPIRED - ) + assert call_kwargs["error_category"] == SourceConnectionErrorCategory.OAUTH_CREDENTIALS_EXPIRED @pytest.mark.asyncio @@ -247,14 +258,17 @@ async def test_non_credential_error_has_no_error_category(): """Non-auth factory error -> error_category=None on state machine transition.""" fake_sm = AsyncMock() fake_factory = MagicMock() - fake_factory.create_orchestrator = AsyncMock( - side_effect=RuntimeError("bad config") - ) + fake_factory.create_orchestrator = AsyncMock(side_effect=RuntimeError("bad config")) svc = SyncService( - state_machine=fake_sm, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=AsyncMock(), + job_state_machine=fake_sm, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) with patch("airweave.domains.syncs.service.get_db_context") as mock_db_ctx: @@ -278,9 +292,565 @@ def test_stores_injected_deps(): fake_sm = MagicMock() fake_factory = MagicMock() svc = SyncService( + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), state_machine=fake_sm, + job_state_machine=MagicMock(), + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) assert svc._state_machine is fake_sm assert svc._sync_factory is fake_factory + + +# --------------------------------------------------------------------------- +# Helper: build a SyncService with configurable mocks +# --------------------------------------------------------------------------- + + +def _build_svc( + sync_repo=None, + sync_job_repo=None, + sync_cursor_repo=None, + state_machine=None, + job_state_machine=None, + temporal_workflow_service=None, + temporal_schedule_service=None, + sync_factory=None, +): + return SyncService( + sync_repo=sync_repo or AsyncMock(), + sync_job_repo=sync_job_repo or AsyncMock(), + sync_cursor_repo=sync_cursor_repo or AsyncMock(), + state_machine=state_machine or AsyncMock(), + job_state_machine=job_state_machine or AsyncMock(), + temporal_workflow_service=temporal_workflow_service or AsyncMock(), + temporal_schedule_service=temporal_schedule_service or AsyncMock(), + sync_factory=sync_factory or MagicMock(), + ) + + +# --------------------------------------------------------------------------- +# get() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_returns_sync(): + repo = AsyncMock() + expected = MagicMock() + repo.get.return_value = expected + + svc = _build_svc(sync_repo=repo) + result = await svc.get(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert result is expected + + +@pytest.mark.asyncio +async def test_get_raises_when_not_found(): + repo = AsyncMock() + repo.get.return_value = None + + svc = _build_svc(sync_repo=repo) + with pytest.raises(HTTPException) as exc_info: + await svc.get(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# pause() / resume() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_pause_delegates_to_state_machine(): + from airweave.core.shared_models import SyncStatus + + sm = AsyncMock() + expected = MagicMock() + sm.transition.return_value = expected + + svc = _build_svc(state_machine=sm) + sid = uuid4() + result = await svc.pause(sid, _mock_ctx(), reason="maintenance") + + assert result is expected + call_kw = sm.transition.call_args.kwargs + assert call_kw["sync_id"] == sid + assert call_kw["target"] == SyncStatus.PAUSED + assert call_kw["reason"] == "maintenance" + + +@pytest.mark.asyncio +async def test_resume_delegates_to_state_machine(): + from airweave.core.shared_models import SyncStatus + + sm = AsyncMock() + expected = MagicMock() + sm.transition.return_value = expected + + svc = _build_svc(state_machine=sm) + sid = uuid4() + result = await svc.resume(sid, _mock_ctx()) + + assert result is expected + assert sm.transition.call_args.kwargs["target"] == SyncStatus.ACTIVE + + +# --------------------------------------------------------------------------- +# resolve_destination_ids() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_resolve_destination_ids_returns_vespa(): + from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID + + svc = _build_svc() + result = await svc.resolve_destination_ids(AsyncMock(), _mock_ctx()) + assert result == [NATIVE_VESPA_UUID] + + +# --------------------------------------------------------------------------- +# get_jobs() +# --------------------------------------------------------------------------- + + +def _orm_sync_job(job_id=None, sync_id=None, status=SyncJobStatus.COMPLETED): + """Create a MagicMock that passes schemas.SyncJob.model_validate().""" + m = MagicMock() + m.id = job_id or uuid4() + m.sync_id = sync_id or uuid4() + m.organization_id = uuid4() + m.status = status + m.scheduled = False + m.entities_inserted = 0 + m.entities_updated = 0 + m.entities_deleted = 0 + m.entities_kept = 0 + m.entities_skipped = 0 + m.entities_encountered = {} + m.started_at = None + m.completed_at = None + m.failed_at = None + m.error = None + m.error_category = None + m.access_token = None + m.sync_config = None + m.sync_metadata = None + m.created_by_email = None + m.modified_by_email = None + m.created_at = None + m.modified_at = None + m.sync_name = None + return m + + +@pytest.mark.asyncio +async def test_get_jobs_returns_validated_schemas(): + job_repo = AsyncMock() + mock_job = _orm_sync_job() + job_repo.get_all_by_sync_id.return_value = [mock_job] + + svc = _build_svc(sync_job_repo=job_repo) + jobs = await svc.get_jobs(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert len(jobs) == 1 + assert jobs[0].id == mock_job.id + + +# --------------------------------------------------------------------------- +# validate_force_full_sync() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_force_full_sync_no_cursor(): + cursor_repo = AsyncMock() + cursor_repo.get_by_sync_id.return_value = None + + svc = _build_svc(sync_cursor_repo=cursor_repo) + ctx = _mock_ctx() + await svc.validate_force_full_sync(AsyncMock(), uuid4(), ctx) + ctx.logger.info.assert_called_once() + assert "no cursor data" in ctx.logger.info.call_args[0][0] + + +@pytest.mark.asyncio +async def test_validate_force_full_sync_with_cursor(): + cursor_repo = AsyncMock() + cursor = MagicMock() + cursor.cursor_data = {"some": "data"} + cursor_repo.get_by_sync_id.return_value = cursor + + svc = _build_svc(sync_cursor_repo=cursor_repo) + ctx = _mock_ctx() + await svc.validate_force_full_sync(AsyncMock(), uuid4(), ctx) + ctx.logger.info.assert_called_once() + assert "Force full sync" in ctx.logger.info.call_args[0][0] + + +# --------------------------------------------------------------------------- +# cancel_job() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_job_not_found(): + job_repo = AsyncMock() + job_repo.get.return_value = None + + svc = _build_svc(sync_job_repo=job_repo) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), job_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_cancel_job_wrong_status(): + job_repo = AsyncMock() + job = MagicMock() + job.status = SyncJobStatus.COMPLETED + job_repo.get.return_value = job + + svc = _build_svc(sync_job_repo=job_repo) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), job_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_cancel_job_success(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.RUNNING) + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": True, + "workflow_found": True, + } + + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + result = await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + job_sm.transition.assert_awaited_once() + assert job_sm.transition.call_args.kwargs["target"] == SyncJobStatus.CANCELLING + temporal.cancel_sync_job_workflow.assert_awaited_once() + assert result is not None + + +@pytest.mark.asyncio +async def test_cancel_job_workflow_not_found_marks_cancelled(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.PENDING) + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": True, + "workflow_found": False, + } + + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + assert job_sm.transition.await_count == 2 + second_call = job_sm.transition.call_args_list[1].kwargs + assert second_call["target"] == SyncJobStatus.CANCELLED + + +@pytest.mark.asyncio +async def test_cancel_job_temporal_failure(): + job_id = uuid4() + job_repo = AsyncMock() + job = MagicMock() + job.id = job_id + job.status = SyncJobStatus.RUNNING + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": False, + "workflow_found": True, + } + + svc = _build_svc(sync_job_repo=job_repo, temporal_workflow_service=temporal) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), job_id=job_id, ctx=_mock_ctx()) + assert exc_info.value.status_code == 502 + + +# --------------------------------------------------------------------------- +# _resolve_cron() / _validate_cron_for_source() +# --------------------------------------------------------------------------- + + +def _mock_source_entry(*, short_name="github", continuous=False, federated=False): + entry = MagicMock() + entry.short_name = short_name + entry.supports_continuous = continuous + entry.federated_search = federated + return entry + + +def test_resolve_cron_explicit(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + result = svc._resolve_cron( + ScheduleConfig(cron="0 6 * * *"), + _mock_source_entry(), + _mock_ctx(), + ) + assert result == "0 6 * * *" + + +def test_resolve_cron_explicit_null(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + result = svc._resolve_cron( + ScheduleConfig(cron=None), + _mock_source_entry(), + _mock_ctx(), + ) + assert result is None + + +def test_resolve_cron_continuous_default(): + from airweave.domains.syncs.types import CONTINUOUS_SOURCE_DEFAULT_CRON + + svc = _build_svc() + result = svc._resolve_cron(None, _mock_source_entry(continuous=True), _mock_ctx()) + assert result == CONTINUOUS_SOURCE_DEFAULT_CRON + + +def test_resolve_cron_daily_default(): + svc = _build_svc() + result = svc._resolve_cron(None, _mock_source_entry(), _mock_ctx()) + assert result is not None + parts = result.split() + assert len(parts) == 5 + assert parts[2:] == ["*", "*", "*"] + + +def test_validate_cron_allows_continuous(): + svc = _build_svc() + svc._validate_cron_for_source("* * * * *", _mock_source_entry(continuous=True)) + + +def test_validate_cron_rejects_every_minute(): + svc = _build_svc() + with pytest.raises(HTTPException) as exc_info: + svc._validate_cron_for_source("* * * * *", _mock_source_entry()) + assert exc_info.value.status_code == 400 + + +def test_validate_cron_rejects_sub_hourly(): + svc = _build_svc() + with pytest.raises(HTTPException) as exc_info: + svc._validate_cron_for_source("*/5 * * * *", _mock_source_entry()) + assert exc_info.value.status_code == 400 + + +# --------------------------------------------------------------------------- +# create() — federated / no-schedule / happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_skips_federated(): + svc = _build_svc() + with pytest.raises(ValueError, match="federated"): + await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(federated=True), + schedule_config=None, + run_immediately=False, + ctx=_mock_ctx(), + uow=MagicMock(), + ) + + +@pytest.mark.asyncio +async def test_create_no_cron_no_run_immediately(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + with pytest.raises(ValueError, match="no schedule"): + await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(), + schedule_config=ScheduleConfig(cron=None), + run_immediately=False, + ctx=_mock_ctx(), + uow=MagicMock(), + ) + + +@pytest.mark.asyncio +async def test_create_with_cron_calls_temporal_schedule(): + from airweave.schemas.source_connection import ScheduleConfig + + sync_repo = AsyncMock() + mock_sync = MagicMock() + mock_sync.id = uuid4() + sync_repo.create.return_value = mock_sync + + job_repo = AsyncMock() + temporal_sched = AsyncMock() + + uow = MagicMock() + uow.session = AsyncMock() + uow.commit = AsyncMock() + + svc = _build_svc( + sync_repo=sync_repo, + sync_job_repo=job_repo, + temporal_schedule_service=temporal_sched, + ) + + result = await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(), + schedule_config=ScheduleConfig(cron="0 6 * * *"), + run_immediately=False, + ctx=_mock_ctx(), + uow=uow, + ) + assert result is not None + assert result.sync_id == mock_sync.id + temporal_sched.create_or_update_schedule.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# delete() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_delete_delegates(): + svc = _build_svc() + svc._cancel_active_sync = AsyncMock(return_value=True) + svc._wait_for_terminal = AsyncMock() + svc._schedule_cleanup = AsyncMock() + + await svc.delete( + AsyncMock(), + sync_id=uuid4(), + collection_id=uuid4(), + organization_id=uuid4(), + ctx=_mock_ctx(), + ) + svc._cancel_active_sync.assert_awaited_once() + svc._wait_for_terminal.assert_awaited_once() + svc._schedule_cleanup.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# _cancel_active_sync() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_active_sync_cancels_running_job(): + sync_id = uuid4() + job = MagicMock() + job.id = uuid4() + job.status = SyncJobStatus.RUNNING + + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = job + + temporal = AsyncMock() + svc = _build_svc(sync_job_repo=job_repo, temporal_workflow_service=temporal) + + result = await svc._cancel_active_sync(AsyncMock(), sync_id, _mock_ctx()) + assert result is True + temporal.cancel_sync_job_workflow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cancel_active_sync_skips_terminal(): + job = MagicMock() + job.status = SyncJobStatus.COMPLETED + + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = job + + svc = _build_svc(sync_job_repo=job_repo) + result = await svc._cancel_active_sync(AsyncMock(), uuid4(), _mock_ctx()) + assert result is False + + +# --------------------------------------------------------------------------- +# _wait_for_terminal() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_wait_for_terminal_returns_when_no_job(): + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = None + svc = _build_svc(sync_job_repo=job_repo) + db = MagicMock() + with patch("airweave.domains.syncs.service.asyncio.sleep", new_callable=AsyncMock): + await svc._wait_for_terminal(db, uuid4(), 5, _mock_ctx()) + job_repo.get_latest_by_sync_id.assert_awaited() + + +# --------------------------------------------------------------------------- +# _schedule_cleanup() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_schedule_cleanup_calls_temporal(): + temporal = AsyncMock() + svc = _build_svc(temporal_workflow_service=temporal) + await svc._schedule_cleanup(uuid4(), uuid4(), uuid4(), _mock_ctx()) + temporal.start_cleanup_sync_data_workflow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_schedule_cleanup_handles_error(): + temporal = AsyncMock() + temporal.start_cleanup_sync_data_workflow.side_effect = RuntimeError("boom") + + svc = _build_svc(temporal_workflow_service=temporal) + ctx = _mock_ctx() + await svc._schedule_cleanup(uuid4(), uuid4(), uuid4(), ctx) + ctx.logger.error.assert_called_once() diff --git a/backend/conftest.py b/backend/conftest.py index 84ad6698d..532161998 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -210,11 +210,11 @@ def fake_health_service() -> FakeHealthService: @pytest.fixture -def fake_source_connection_service(fake_sync_lifecycle): +def fake_source_connection_service(fake_sync_service): """Fake SourceConnectionService.""" from airweave.domains.source_connections.fakes.service import FakeSourceConnectionService - return FakeSourceConnectionService(sync_lifecycle=fake_sync_lifecycle) + return FakeSourceConnectionService(sync_service=fake_sync_service) @pytest.fixture @@ -375,11 +375,9 @@ def fake_billing_service(): @pytest.fixture -def fake_sync_record_service(): - """Fake SyncRecordService.""" - from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService - - return FakeSyncRecordService() +def fake_sync_record_service(fake_sync_service): + """Legacy fixture — returns the unified FakeSyncService for backward compatibility.""" + return fake_sync_service @pytest.fixture @@ -407,11 +405,9 @@ def fake_sync_service(): @pytest.fixture -def fake_sync_lifecycle(): - """Fake SyncLifecycleService.""" - from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService - - return FakeSyncLifecycleService() +def fake_sync_lifecycle(fake_sync_service): + """Legacy fixture — returns the unified FakeSyncService for backward compatibility.""" + return fake_sync_service @pytest.fixture @@ -717,12 +713,9 @@ def test_container( fake_sync_cursor_repo, fake_sync_cursor_service, fake_sync_job_repo, - fake_sync_record_service, fake_sync_job_service, fake_sync_job_state_machine, - fake_sync_state_machine, fake_sync_service, - fake_sync_lifecycle, fake_billing_service, fake_billing_webhook, fake_payment_gateway, @@ -809,12 +802,9 @@ def test_container( sync_cursor_repo=fake_sync_cursor_repo, sync_cursor_service=fake_sync_cursor_service, sync_job_repo=fake_sync_job_repo, - sync_record_service=fake_sync_record_service, sync_job_service=fake_sync_job_service, sync_job_state_machine=fake_sync_job_state_machine, - sync_state_machine=fake_sync_state_machine, sync_service=fake_sync_service, - sync_lifecycle=fake_sync_lifecycle, billing_service=fake_billing_service, billing_webhook=fake_billing_webhook, payment_gateway=fake_payment_gateway, From e0628be0da9b2c0f8a1d9824431944467f3aa81d Mon Sep 17 00:00:00 2001 From: Felix Date: Wed, 15 Apr 2026 02:32:11 -0700 Subject: [PATCH 05/25] feat: custom auth provider (#1668) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add Custom auth provider for customer-hosted token endpoints Adds a new "custom" auth provider that calls a user-hosted HTTP endpoint to fetch fresh access tokens before each sync. This enables clients who manage their own token lifecycle (e.g. Mozilla) to expose a single endpoint instead of injecting expiring OAUTH_TOKEN values. Backend: - CustomAuthConfig with endpoint_url (SSRF-validated), auth method, and auth value fields stored encrypted on the provider connection - CustomAuthProvider implementation with POST-based token fetching, configurable auth (bearer/api_key_header/none), and full error mapping - Registered in ALL_AUTH_PROVIDERS with empty BLOCKED_SOURCES - Unit tests for create, headers, get_creds_for_source, and validate Frontend: - Fix AuthProviderTable to support multiple connections per provider type (was using .find(), now uses .filter() with list view) - New AuthProviderConnectionsList component for browsing connections - Connection count badge on AuthProviderButton when count > 1 - Custom provider icon (plug + arrow) for light and dark themes Co-Authored-By: Claude Opus 4.6 (1M context) * fix: add request-time SSRF protection, simplify validate contract, clean up debug logs Add validate_url() check before every HTTP request in CustomAuthProvider to prevent DNS rebinding attacks against the user-supplied endpoint URL. Simplify validate() to just check for 2xx at base URL instead of requiring access_token at /__validate__. Document the full endpoint contract in CustomAuthConfig. Remove console.log debug statements and unused import from AuthProviderTable. * fix: remove remaining debug console.log statements from auth provider UI * fix: route auth provider connections through correct token provider and error classification Two bugs found during Custom auth provider e2e testing: 1. lifecycle._get_auth_configuration required both readable_auth_provider_id AND auth_provider_config to route to the auth provider path. Custom provider has no config, so it fell through to the DB credential path and failed with "no integration credential". 2. lifecycle._resolve_token_provider only used AuthProviderTokenProvider for OAuth sources. Non-OAuth sources (GitHub, Stripe) got a DirectCredentialProvider even when credentials came from an auth provider, causing errors to be classified as API_KEY_INVALID instead of AUTH_PROVIDER_CREDENTIALS_INVALID. 3. AuthProviderTokenProvider._fetch_token hardcoded "access_token" as the expected field. Non-OAuth sources use different field names (e.g. personal_access_token for GitHub). Now extracts the first matching runtime auth field from the response. * fix: improve error handling for auth provider credential failures - Handle 404 from Custom endpoint as AuthProviderMissingFieldsError with a clear message instead of leaking raw httpx error - Catch remaining HTTP status codes as AuthProviderConfigError instead of re-raising raw httpx exceptions - Wrap AuthProviderError from credential fetching (step 2) as SourceValidationError so the error classifier can process it - Add AuthProviderError catch-all in classifier to map all auth provider failures to AUTH_PROVIDER_CREDENTIALS_INVALID * fix: scope Custom auth provider credentials by source connection ID Custom endpoint now calls GET {base_url}/{source_connection_id} instead of GET {base_url}/{source_short_name}, allowing customers to return different credentials per source connection. Added optional source_connection_id parameter through the base class, lifecycle, and token provider chain. Custom provider requires it and errors if missing. * feat: normalize Custom auth provider field names and simplify customer contract Customers return {"access_token": "..."} or {"api_key": "..."} — the Custom provider maps to Airweave-internal field names (e.g. personal_access_token for GitHub, api_token for Document360). Also treat refresh_token, client_id, client_secret as optional for auth provider flows since the provider handles token lifecycle. * fix: treat OAuth lifecycle fields as optional in token provider path too The previous fix only applied _AUTH_PROVIDER_OPTIONAL in the lifecycle path. The token provider's _call_provider_with_retry also calls get_creds_for_source with the raw optional fields, causing Dropbox syncs to fail demanding refresh_token. * fix: harden Custom auth provider security, clean up debug logs, improve UI - Block SSRF via redirect by setting follow_redirects=False on httpx clients - Fix stale docstrings referencing source_short_name instead of source_connection_id - Deduplicate AUTH_PROVIDER_OPTIONAL_FIELDS into shared constant in _base.py - Type AuthProviderConnectionsList props with proper AuthProvider/AuthProviderConnection types - Remove 19 debug console.log statements across frontend auth provider components - Add "View endpoint contract" tooltip in Custom provider configure UI - Add tests verifying follow_redirects=False is enforced * feat: gate Custom auth provider behind custom_auth_provider feature flag Add feature_flag support to the @auth_provider decorator, mirroring the existing pattern for sources. Providers with a feature flag are hidden from list_metadata unless the organization has the flag enabled. * fix: use set() instead of list for runtime_auth_optional_fields in test The | operator requires set-like types on both sides. The test was using [] but AUTH_PROVIDER_OPTIONAL_FIELDS is a frozenset. * refactor: extract HTTP status classification to reduce complexity Moves the HTTP status code → exception mapping out of get_creds_for_source into _raise_for_http_status to satisfy ruff C901 (complexity 13 > 10). * fix: resolve mypy violations in diff - Declare endpoint_url/api_key as instance attrs on CustomAuthProvider - Type credentials param as Optional[Dict] instead of Optional[Any] - Cast BaseContext to ApiContext at _get_auth_configuration call site - Use str() to satisfy return type on credential extraction - Remove dead _replace_fields call in test * fix: block ctti source in Custom provider and update e2e test allowlist Custom provider had empty BLOCKED_SOURCES, so ctti (internal test source) showed it as a supported provider. Also add 'custom' to the valid auth provider names in the smoke test. * fix: add source_connection_id param to Composio and Pipedream providers The base class and callers now pass source_connection_id as a keyword argument. Composio and Pipedream need to accept it in their signatures even though they don't use it. * refactor: rename endpoint_url to base_endpoint_url and type create() Address PR review: rename the Custom auth provider's URL field to be more descriptive, and parse credentials into CustomAuthConfig inside create() for typed attribute access. --------- Co-authored-by: Claude Opus 4.6 (1M context) --- backend/airweave/core/shared_models.py | 1 + .../airweave/domains/auth_provider/_base.py | 18 +- .../auth_provider/providers/__init__.py | 2 + .../auth_provider/providers/composio.py | 2 + .../domains/auth_provider/providers/custom.py | 223 +++++++++++ .../auth_provider/providers/pipedream.py | 2 + .../domains/auth_provider/registry.py | 3 + .../airweave/domains/auth_provider/service.py | 28 +- .../auth_provider/tests/test_custom.py | 355 ++++++++++++++++++ .../auth_provider/tests/test_service.py | 72 +++- .../airweave/domains/auth_provider/types.py | 3 + .../domains/sources/exceptions/classifier.py | 8 + backend/airweave/domains/sources/lifecycle.py | 68 ++-- .../domains/sources/tests/test_lifecycle.py | 2 +- .../sources/token_providers/auth_provider.py | 29 +- backend/airweave/platform/configs/auth.py | 56 +++ backend/airweave/platform/configs/config.py | 6 + backend/airweave/platform/decorators.py | 3 + backend/tests/e2e/smoke/test_sources.py | 2 +- .../platform/sync/test_token_providers.py | 2 +- .../AuthProviderConnectionsList.tsx | 92 +++++ .../auth-providers/AuthProviderDetailView.tsx | 16 - .../auth-providers/AuthProviderDialog.tsx | 26 +- .../auth-providers/AuthProviderTable.tsx | 82 +--- .../ConfigureAuthProviderView.tsx | 79 ++-- .../dashboard/AuthProviderButton.tsx | 26 +- .../icons/auth_providers/custom-dark.svg | 13 + .../icons/auth_providers/custom-light.svg | 13 + frontend/src/lib/constants/feature-flags.ts | 3 + frontend/src/lib/stores/authProviders.ts | 1 - 30 files changed, 1049 insertions(+), 187 deletions(-) create mode 100644 backend/airweave/domains/auth_provider/providers/custom.py create mode 100644 backend/airweave/domains/auth_provider/tests/test_custom.py create mode 100644 frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx create mode 100644 frontend/src/components/icons/auth_providers/custom-dark.svg create mode 100644 frontend/src/components/icons/auth_providers/custom-light.svg diff --git a/backend/airweave/core/shared_models.py b/backend/airweave/core/shared_models.py index 70227466a..7139a8fb9 100644 --- a/backend/airweave/core/shared_models.py +++ b/backend/airweave/core/shared_models.py @@ -109,6 +109,7 @@ class FeatureFlag(str, Enum): # These allow specific admin operations via API key authentication API_KEY_ADMIN_SYNC = "api_key_admin_sync" # Allows resync operations via API key CONNECT = "connect" # Enables the Connect playground and embeddable widget features + CUSTOM_AUTH_PROVIDER = "custom_auth_provider" # Enables the Custom auth provider class AuthMethod(str, Enum): diff --git a/backend/airweave/domains/auth_provider/_base.py b/backend/airweave/domains/auth_provider/_base.py index 24e833b40..537acbdf6 100644 --- a/backend/airweave/domains/auth_provider/_base.py +++ b/backend/airweave/domains/auth_provider/_base.py @@ -2,12 +2,19 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Set +from uuid import UUID from pydantic import BaseModel from airweave.core.logging import logger from airweave.domains.auth_provider.auth_result import AuthResult +# OAuth lifecycle fields that auth providers handle internally — +# always optional when fetching credentials from a provider. +AUTH_PROVIDER_OPTIONAL_FIELDS: frozenset[str] = frozenset( + {"refresh_token", "client_id", "client_secret"} +) + class BaseAuthProvider(ABC): """Base class for all auth providers.""" @@ -19,6 +26,7 @@ class BaseAuthProvider(ABC): auth_config_class: ClassVar[Optional[type[BaseModel]]] = None config_class: ClassVar[Optional[type[BaseModel]]] = None SETTINGS_URL: ClassVar[str] = "" + feature_flag: ClassVar[Optional[str]] = None def __init__(self): """Initialize the base auth provider.""" @@ -58,6 +66,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a source. @@ -65,6 +74,8 @@ async def get_creds_for_source( source_short_name: The short name of the source to get credentials for source_auth_config_fields: The fields required for the source auth config optional_fields: Fields that can be skipped if the provider doesn't have them + source_connection_id: UUID of the source connection (used by Custom provider + to scope credentials per connection) """ pass @@ -104,6 +115,7 @@ async def get_auth_result( source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, source_config_field_mappings: Optional[Dict[str, str]] = None, + source_connection_id: Optional[UUID] = None, ) -> AuthResult: """Get auth result with credentials for a source. @@ -115,12 +127,16 @@ async def get_auth_result( source_auth_config_fields: The fields required for the source auth config optional_fields: Fields that can be skipped if the provider doesn't have them source_config_field_mappings: Mapping of config fields extractable from auth response + source_connection_id: UUID of the source connection Returns: AuthResult with credentials and optional source config """ credentials = await self.get_creds_for_source( - source_short_name, source_auth_config_fields, optional_fields + source_short_name, + source_auth_config_fields, + optional_fields, + source_connection_id=source_connection_id, ) source_config = {} diff --git a/backend/airweave/domains/auth_provider/providers/__init__.py b/backend/airweave/domains/auth_provider/providers/__init__.py index 8d5dd9306..7805a13e0 100644 --- a/backend/airweave/domains/auth_provider/providers/__init__.py +++ b/backend/airweave/domains/auth_provider/providers/__init__.py @@ -1,9 +1,11 @@ """Auth provider implementations.""" from .composio import ComposioAuthProvider +from .custom import CustomAuthProvider from .pipedream import PipedreamAuthProvider ALL_AUTH_PROVIDERS: list[type] = [ ComposioAuthProvider, + CustomAuthProvider, PipedreamAuthProvider, ] diff --git a/backend/airweave/domains/auth_provider/providers/composio.py b/backend/airweave/domains/auth_provider/providers/composio.py index edfb52f48..9d95daf5f 100644 --- a/backend/airweave/domains/auth_provider/providers/composio.py +++ b/backend/airweave/domains/auth_provider/providers/composio.py @@ -1,6 +1,7 @@ """Composio Test Auth Provider - provides authentication services for other integrations.""" from typing import Any, Dict, List, Optional, Set +from uuid import UUID import httpx @@ -211,6 +212,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a specific source integration. diff --git a/backend/airweave/domains/auth_provider/providers/custom.py b/backend/airweave/domains/auth_provider/providers/custom.py new file mode 100644 index 000000000..9440a7baf --- /dev/null +++ b/backend/airweave/domains/auth_provider/providers/custom.py @@ -0,0 +1,223 @@ +"""Custom Auth Provider - fetches tokens from a customer-hosted HTTP endpoint.""" + +from typing import Any, Dict, List, Optional, Set +from uuid import UUID + +import httpx + +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.platform.configs.auth import CustomAuthConfig +from airweave.platform.configs.config import CustomConfig +from airweave.platform.decorators import auth_provider +from airweave.platform.utils.ssrf import SSRFViolation, validate_url + + +@auth_provider( + name="Custom", + short_name="custom", + auth_config_class=CustomAuthConfig, + config_class=CustomConfig, + feature_flag="custom_auth_provider", +) +class CustomAuthProvider(BaseAuthProvider): + """Custom authentication provider. + + Calls GET {base_url}/{source_connection_id} on a customer-hosted endpoint + to fetch fresh credentials. The customer is responsible for returning + the freshest credentials as JSON. + """ + + BLOCKED_SOURCES: list[str] = ["ctti"] + + # Map Airweave-internal field names to the simple names customers return. + # Customers always return {"access_token": "..."} or {"api_key": "..."}. + FIELD_NAME_MAPPING: Dict[str, str] = { + "personal_access_token": "access_token", # GitHub + "api_token": "access_token", # Document360, Pipedrive + } + + # Instance attributes set in create() + base_endpoint_url: str + api_key: str + + @classmethod + async def create( + cls, + credentials: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "CustomAuthProvider": + """Create a new Custom auth provider instance.""" + if credentials is None: + raise ValueError("credentials parameter is required") + auth_config = CustomAuthConfig(**credentials) + instance = cls() + instance.base_endpoint_url = auth_config.base_endpoint_url + instance.api_key = auth_config.api_key + return instance + + def _build_headers(self) -> Dict[str, str]: + """Build request headers with API key authentication.""" + return { + "Accept": "application/json", + "X-API-Key": self.api_key, + } + + def _check_ssrf(self, url: str) -> None: + """Validate URL against SSRF blocklist before making a request.""" + try: + validate_url(url) + except SSRFViolation as exc: + self.logger.warning(f"[Custom] SSRF blocked: {exc}") + raise AuthProviderConfigError( + f"Custom endpoint URL blocked by SSRF policy: {exc}", + provider_name="custom", + ) from exc + + def _raise_for_http_status(self, e: httpx.HTTPStatusError, source_short_name: str) -> None: + """Classify an HTTP error status into the appropriate auth provider exception.""" + status = e.response.status_code + self.logger.error(f"[Custom] HTTP {status} from endpoint for source '{source_short_name}'") + if status in (401, 403): + raise AuthProviderAuthError( + f"Custom endpoint returned {status} for source '{source_short_name}'", + provider_name="custom", + ) from e + if status == 429: + retry_after = float(e.response.headers.get("retry-after", 30)) + raise AuthProviderRateLimitError( + f"Custom endpoint rate-limited for source '{source_short_name}'", + provider_name="custom", + retry_after=retry_after, + ) from e + if status == 404: + raise AuthProviderMissingFieldsError( + f"Custom endpoint has no credentials configured for " + f"source '{source_short_name}' (404)", + provider_name="custom", + missing_fields=[], + available_fields=[], + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Custom endpoint returned {status} for source '{source_short_name}'", + provider_name="custom", + status_code=status, + ) from e + raise AuthProviderConfigError( + f"Custom endpoint returned unexpected {status} for source '{source_short_name}'", + provider_name="custom", + ) from e + + async def get_creds_for_source( + self, + source_short_name: str, + source_auth_config_fields: List[str], + optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, + ) -> Dict[str, Any]: + """Get credentials for a source by calling GET {base_url}/{source_connection_id}.""" + if not source_connection_id: + raise AuthProviderConfigError( + "Custom auth provider requires a source_connection_id", + provider_name="custom", + ) + _optional_fields = optional_fields or set() + headers = self._build_headers() + url = f"{self.base_endpoint_url}/{source_connection_id}" + + self._check_ssrf(url) + self.logger.info(f"[Custom] Fetching credentials for source '{source_short_name}'") + + async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client: + try: + response = await client.get(url, headers=headers) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + self._raise_for_http_status(e, source_short_name) + except (httpx.ConnectError, httpx.TimeoutException) as e: + self.logger.error(f"[Custom] Network error reaching endpoint: {e}") + raise AuthProviderTemporaryError( + f"Custom endpoint unreachable: {e}", + provider_name="custom", + ) from e + + missing_fields = [] + found_credentials: Dict[str, Any] = {} + + for field in source_auth_config_fields: + # Check the response using the mapped name (e.g. access_token for + # personal_access_token), then store under the Airweave-internal name. + mapped = self.FIELD_NAME_MAPPING.get(field, field) + if mapped in data: + found_credentials[field] = data[mapped] + elif field in data: + found_credentials[field] = data[field] + elif field not in _optional_fields: + missing_fields.append(mapped) + + if missing_fields: + available = list(data.keys()) + self.logger.error( + f"[Custom] Missing required fields for source '{source_short_name}': " + f"{missing_fields}. Available: {available}" + ) + raise AuthProviderMissingFieldsError( + f"Custom endpoint response missing required fields for " + f"source '{source_short_name}': {missing_fields}", + provider_name="custom", + missing_fields=missing_fields, + available_fields=available, + ) + + self.logger.info( + f"[Custom] Successfully retrieved {len(found_credentials)} credential fields " + f"for source '{source_short_name}'" + ) + return found_credentials + + async def validate(self) -> bool: + """Validate the custom endpoint by calling GET {base_url}.""" + headers = self._build_headers() + url = self.base_endpoint_url + + self._check_ssrf(url) + self.logger.info("[Custom] Validating endpoint") + + try: + async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client: + response = await client.get(url, headers=headers) + response.raise_for_status() + + self.logger.info("[Custom] Endpoint validated successfully") + return True + + except httpx.HTTPStatusError as e: + status = e.response.status_code + if status in (401, 403): + raise AuthProviderAuthError( + f"Custom endpoint validation failed: {status}", + provider_name="custom", + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Custom endpoint validation failed: {status}", + provider_name="custom", + status_code=status, + ) from e + raise AuthProviderConfigError( + f"Custom endpoint validation failed: HTTP {status}", + provider_name="custom", + ) from e + except (httpx.ConnectError, httpx.TimeoutException) as e: + raise AuthProviderTemporaryError( + f"Custom endpoint unreachable during validation: {e}", + provider_name="custom", + ) from e diff --git a/backend/airweave/domains/auth_provider/providers/pipedream.py b/backend/airweave/domains/auth_provider/providers/pipedream.py index edccb5de8..94f1bedb7 100644 --- a/backend/airweave/domains/auth_provider/providers/pipedream.py +++ b/backend/airweave/domains/auth_provider/providers/pipedream.py @@ -2,6 +2,7 @@ import time from typing import Any, Dict, List, Optional, Set +from uuid import UUID import httpx @@ -278,6 +279,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a source from Pipedream. diff --git a/backend/airweave/domains/auth_provider/registry.py b/backend/airweave/domains/auth_provider/registry.py index 4608f3def..800d71509 100644 --- a/backend/airweave/domains/auth_provider/registry.py +++ b/backend/airweave/domains/auth_provider/registry.py @@ -93,6 +93,7 @@ def _build_entry(provider_cls: type) -> AuthProviderRegistryEntry: field_name_mapping: dict[str, str] = getattr(provider_cls, "FIELD_NAME_MAPPING", {}) slug_name_mapping: dict[str, str] = getattr(provider_cls, "SLUG_NAME_MAPPING", {}) settings_url: str = getattr(provider_cls, "SETTINGS_URL", "") + feature_flag: str | None = getattr(provider_cls, "feature_flag", None) # ------------------------------------------------------------------ # Precompute fields @@ -120,4 +121,6 @@ def _build_entry(provider_cls: type) -> AuthProviderRegistryEntry: slug_name_mapping=slug_name_mapping, # Settings URL settings_url=settings_url, + # Feature flag + feature_flag=feature_flag, ) diff --git a/backend/airweave/domains/auth_provider/service.py b/backend/airweave/domains/auth_provider/service.py index b44cb9994..f8ff7d1cd 100644 --- a/backend/airweave/domains/auth_provider/service.py +++ b/backend/airweave/domains/auth_provider/service.py @@ -10,7 +10,7 @@ from airweave.core import credentials from airweave.core.datetime_utils import utc_now_naive from airweave.core.exceptions import InvalidInputError, InvalidStateError, NotFoundException -from airweave.core.shared_models import ConnectionStatus, IntegrationType +from airweave.core.shared_models import ConnectionStatus, FeatureFlag, IntegrationType from airweave.db.unit_of_work import UnitOfWork from airweave.domains.auth_provider.protocols import ( AuthProviderRegistryProtocol, @@ -52,8 +52,17 @@ async def list_connections( return result async def list_metadata(self, *, ctx: ApiContext) -> list[AuthProviderMetadata]: - """List auth provider metadata from registry.""" - return [self._entry_to_metadata(entry) for entry in self._registry.list_all()] + """List auth provider metadata from registry. + + Entries gated by a feature flag are excluded unless the organization + has that flag enabled. + """ + enabled_features = ctx.organization.enabled_features or [] + return [ + self._entry_to_metadata(entry) + for entry in self._registry.list_all() + if not self._is_hidden_by_feature_flag(entry, enabled_features) + ] async def get_metadata(self, *, short_name: str, ctx: ApiContext) -> AuthProviderMetadata: """Get auth provider metadata by short name from registry.""" @@ -388,6 +397,19 @@ async def _to_schema( masked_client_id=masked_client_id, ) + @staticmethod + def _is_hidden_by_feature_flag( + entry: AuthProviderRegistryEntry, enabled_features: list[FeatureFlag] + ) -> bool: + """Return True if the entry requires a feature flag the org doesn't have.""" + if not entry.feature_flag: + return False + try: + required = FeatureFlag(entry.feature_flag) + return required not in enabled_features + except ValueError: + return False + @staticmethod def _entry_to_metadata(entry: AuthProviderRegistryEntry) -> AuthProviderMetadata: """Map a registry entry to public metadata.""" diff --git a/backend/airweave/domains/auth_provider/tests/test_custom.py b/backend/airweave/domains/auth_provider/tests/test_custom.py new file mode 100644 index 000000000..84e978115 --- /dev/null +++ b/backend/airweave/domains/auth_provider/tests/test_custom.py @@ -0,0 +1,355 @@ +"""Tests for CustomAuthProvider.""" + +from unittest.mock import AsyncMock, patch +from uuid import UUID + +import httpx +import pytest + +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.domains.auth_provider.providers.custom import CustomAuthProvider + +TEST_SC_ID = UUID("d035439c-dc7d-4813-a207-c68e548cfe51") + + +@pytest.fixture +async def provider(): + """Create a Custom provider.""" + return await CustomAuthProvider.create( + credentials={ + "base_endpoint_url": "https://api.example.com/tokens", + "api_key": "my-secret-key", + } + ) + + +class TestCreate: + """Tests for CustomAuthProvider.create().""" + + @pytest.mark.unit + async def test_create(self, provider): + assert provider.base_endpoint_url == "https://api.example.com/tokens" + assert provider.api_key == "my-secret-key" + + @pytest.mark.unit + async def test_create_strips_trailing_slash(self): + p = await CustomAuthProvider.create( + credentials={ + "base_endpoint_url": "https://api.example.com/tokens/", + "api_key": "key", + } + ) + assert p.base_endpoint_url == "https://api.example.com/tokens" + + +class TestBuildHeaders: + """Tests for _build_headers().""" + + @pytest.mark.unit + async def test_headers(self, provider): + headers = provider._build_headers() + assert headers["Accept"] == "application/json" + assert headers["X-API-Key"] == "my-secret-key" + + +class TestGetCredsForSource: + """Tests for get_creds_for_source().""" + + @pytest.mark.unit + async def test_requires_source_connection_id(self, provider): + with pytest.raises(AuthProviderConfigError, match="source_connection_id"): + await provider.get_creds_for_source("slack", ["access_token"]) + + @pytest.mark.unit + async def test_success(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "eyJ-gdrive-token", "refresh_token": "rt-123"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "google_drive", + ["access_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"access_token": "eyJ-gdrive-token"} + + @pytest.mark.unit + async def test_maps_access_token_to_personal_access_token(self, provider): + """Customer returns access_token, provider maps to personal_access_token for GitHub.""" + mock_response = httpx.Response( + 200, + json={"access_token": "ghp_test123"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "github", + ["personal_access_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"personal_access_token": "ghp_test123"} + + @pytest.mark.unit + async def test_maps_access_token_to_api_token(self, provider): + """Customer returns access_token, provider maps to api_token for Document360.""" + mock_response = httpx.Response( + 200, + json={"access_token": "doc360_token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "document360", + ["api_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"api_token": "doc360_token"} + + @pytest.mark.unit + async def test_calls_correct_url(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_get: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + mock_get.assert_called_once() + call_args = mock_get.call_args + assert call_args.args[0] == f"https://api.example.com/tokens/{TEST_SC_ID}" + + @pytest.mark.unit + async def test_optional_fields_not_required(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "google_drive", + ["access_token", "refresh_token"], + optional_fields={"refresh_token"}, + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"access_token": "token"} + + @pytest.mark.unit + async def test_error_401(self, provider): + mock_response = httpx.Response( + 401, + json={"error": "unauthorized"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderAuthError, match="401"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_429(self, provider): + mock_response = httpx.Response( + 429, + json={"error": "rate limited"}, + headers={"retry-after": "60"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + assert exc_info.value.retry_after == 60.0 + + @pytest.mark.unit + async def test_error_500(self, provider): + mock_response = httpx.Response( + 500, + json={"error": "internal"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderTemporaryError, match="500"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_timeout(self, provider): + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + side_effect=httpx.TimeoutException("timed out"), + ): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_missing_fields(self, provider): + mock_response = httpx.Response( + 200, + json={"some_other_field": "value"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderMissingFieldsError) as exc_info: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + assert "access_token" in exc_info.value.missing_fields + + @pytest.mark.unit + async def test_error_404(self, provider): + mock_response = httpx.Response( + 404, + json={"error": "not found"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderMissingFieldsError, match="404"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_ssrf_blocked(self, provider): + provider.base_endpoint_url = "http://169.254.169.254/latest/meta-data" + + with pytest.raises(AuthProviderConfigError, match="SSRF"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + +class TestValidate: + """Tests for validate().""" + + @pytest.mark.unit + async def test_validate_success(self, provider): + mock_response = httpx.Response( + 200, + json={"status": "ok"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + result = await provider.validate() + + assert result is True + + @pytest.mark.unit + async def test_validate_auth_error(self, provider): + mock_response = httpx.Response( + 401, + json={"error": "unauthorized"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderAuthError): + await provider.validate() + + @pytest.mark.unit + async def test_validate_server_error(self, provider): + mock_response = httpx.Response( + 503, + json={"error": "unavailable"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderTemporaryError): + await provider.validate() + + @pytest.mark.unit + async def test_validate_timeout(self, provider): + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + side_effect=httpx.TimeoutException("timed out"), + ): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await provider.validate() + + @pytest.mark.unit + async def test_validate_ssrf_blocked(self, provider): + provider.base_endpoint_url = "http://169.254.169.254/latest/meta-data" + + with pytest.raises(AuthProviderConfigError, match="SSRF"): + await provider.validate() + + +class TestFollowRedirectsDisabled: + """Verify httpx.AsyncClient is created with follow_redirects=False.""" + + @pytest.mark.unit + async def test_get_creds_no_follow_redirects(self, provider): + with patch( + "airweave.domains.auth_provider.providers.custom.httpx.AsyncClient" + ) as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get.return_value = httpx.Response( + 200, + json={"access_token": "tok"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + mock_client_cls.return_value = mock_client + + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + mock_client_cls.assert_called_once_with(timeout=30.0, follow_redirects=False) + + @pytest.mark.unit + async def test_validate_no_follow_redirects(self, provider): + with patch( + "airweave.domains.auth_provider.providers.custom.httpx.AsyncClient" + ) as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get.return_value = httpx.Response( + 200, + json={"status": "ok"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + mock_client_cls.return_value = mock_client + + await provider.validate() + + mock_client_cls.assert_called_once_with(timeout=30.0, follow_redirects=False) diff --git a/backend/airweave/domains/auth_provider/tests/test_service.py b/backend/airweave/domains/auth_provider/tests/test_service.py index afac3cfa3..33c6be419 100644 --- a/backend/airweave/domains/auth_provider/tests/test_service.py +++ b/backend/airweave/domains/auth_provider/tests/test_service.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from airweave.core.exceptions import InvalidInputError, InvalidStateError, NotFoundException -from airweave.core.shared_models import IntegrationType +from airweave.core.shared_models import FeatureFlag, IntegrationType from airweave.domains.auth_provider.service import AuthProviderService from airweave.domains.auth_provider.types import AuthProviderRegistryEntry from airweave.platform.configs._base import Fields @@ -403,3 +403,73 @@ async def test_to_schema_include_and_exclude_masked(): no_mask = await service._to_schema("db", conn, _ctx(), include_masked_client_id=False) assert no_mask.masked_client_id is None + + +# --------------------------------------------------------------------------- +# Feature flag filtering +# --------------------------------------------------------------------------- + +_UNFLAGGED = _entry(short_name="composio") +_FLAGGED = AuthProviderRegistryEntry( + **{**_entry(short_name="custom").model_dump(), "name": "Custom", "feature_flag": "custom_auth_provider"} +) + + +def _ctx_with_features(features: list[FeatureFlag] | None = None): + """Build a context with organization.enabled_features.""" + org = SimpleNamespace(enabled_features=features or []) + logger = SimpleNamespace(info=lambda *a, **k: None, error=lambda *a, **k: None) + return SimpleNamespace( + logger=logger, + organization=org, + has_user_context=True, + tracking_email="owner@airweave.ai", + ) + + +class TestIsHiddenByFeatureFlag: + """Tests for _is_hidden_by_feature_flag.""" + + def test_no_flag_always_visible(self): + assert AuthProviderService._is_hidden_by_feature_flag(_UNFLAGGED, []) is False + + def test_flag_missing_from_org(self): + assert AuthProviderService._is_hidden_by_feature_flag(_FLAGGED, []) is True + + def test_flag_present_in_org(self): + assert ( + AuthProviderService._is_hidden_by_feature_flag( + _FLAGGED, [FeatureFlag.CUSTOM_AUTH_PROVIDER] + ) + is False + ) + + def test_unknown_flag_fails_open(self): + entry = AuthProviderRegistryEntry( + **{**_UNFLAGGED.model_dump(), "feature_flag": "nonexistent_flag"} + ) + assert AuthProviderService._is_hidden_by_feature_flag(entry, []) is False + + +@pytest.mark.asyncio +async def test_list_metadata_hides_flagged_provider(): + registry = SimpleNamespace(list_all=lambda: [_UNFLAGGED, _FLAGGED]) + service = AuthProviderService(registry, connection_repo=None, credential_repo=None) + + ctx = _ctx_with_features() + result = await service.list_metadata(ctx=ctx) + names = [m.short_name for m in result] + assert "composio" in names + assert "custom" not in names + + +@pytest.mark.asyncio +async def test_list_metadata_shows_flagged_provider_with_flag(): + registry = SimpleNamespace(list_all=lambda: [_UNFLAGGED, _FLAGGED]) + service = AuthProviderService(registry, connection_repo=None, credential_repo=None) + + ctx = _ctx_with_features([FeatureFlag.CUSTOM_AUTH_PROVIDER]) + result = await service.list_metadata(ctx=ctx) + names = [m.short_name for m in result] + assert "composio" in names + assert "custom" in names diff --git a/backend/airweave/domains/auth_provider/types.py b/backend/airweave/domains/auth_provider/types.py index 4704756ed..19892ffbd 100644 --- a/backend/airweave/domains/auth_provider/types.py +++ b/backend/airweave/domains/auth_provider/types.py @@ -24,6 +24,9 @@ class AuthProviderRegistryEntry(BaseRegistryEntry): # Settings dashboard URL settings_url: str = "" + # Feature flag gating + feature_flag: str | None = None + class AuthProviderMetadata(BaseRegistryEntry): """Public auth provider metadata returned by API endpoints.""" diff --git a/backend/airweave/domains/sources/exceptions/classifier.py b/backend/airweave/domains/sources/exceptions/classifier.py index 5641f512a..5b555219c 100644 --- a/backend/airweave/domains/sources/exceptions/classifier.py +++ b/backend/airweave/domains/sources/exceptions/classifier.py @@ -66,6 +66,14 @@ def classify_error(exc: Exception) -> ErrorClassification: message=str(exc), ) + # Catch-all for remaining AuthProviderError subtypes (e.g. MissingFieldsError, + # ConfigError) — these are auth provider issues the user needs to address. + if isinstance(exc, AuthProviderError): + return ErrorClassification( + category=SourceConnectionErrorCategory.AUTH_PROVIDER_CREDENTIALS_INVALID, + message=str(exc), + ) + # --- Legacy SourceTokenRefreshError --- if isinstance(exc, SourceTokenRefreshError): return ErrorClassification( diff --git a/backend/airweave/domains/sources/lifecycle.py b/backend/airweave/domains/sources/lifecycle.py index 8c147f3de..cc568f47d 100644 --- a/backend/airweave/domains/sources/lifecycle.py +++ b/backend/airweave/domains/sources/lifecycle.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast from uuid import UUID if TYPE_CHECKING: @@ -22,10 +22,11 @@ from airweave.core.exceptions import NotFoundException from airweave.core.logging import ContextualLogger, LoggerConfigurator from airweave.core.shared_models import FeatureFlag -from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider._base import AUTH_PROVIDER_OPTIONAL_FIELDS, BaseAuthProvider from airweave.domains.auth_provider.exceptions import ( AuthProviderAccountNotFoundError, AuthProviderAuthError, + AuthProviderError, ) from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -40,11 +41,6 @@ SourceNotFoundError, SourceValidationError, ) -from airweave.domains.sources.token_providers.exceptions import ( - TokenCredentialsInvalidError, - TokenExpiredError, - TokenProviderAccountGoneError, -) from airweave.domains.sources.protocols import ( SourceLifecycleServiceProtocol, SourceRegistryProtocol, @@ -52,6 +48,11 @@ from airweave.domains.sources.rate_limiting.service import SourceRateLimiter from airweave.domains.sources.token_providers.auth_provider import AuthProviderTokenProvider from airweave.domains.sources.token_providers.credential import DirectCredentialProvider +from airweave.domains.sources.token_providers.exceptions import ( + TokenCredentialsInvalidError, + TokenExpiredError, + TokenProviderAccountGoneError, +) from airweave.domains.sources.token_providers.oauth import OAuthTokenProvider from airweave.domains.sources.token_providers.static import StaticTokenProvider from airweave.domains.sources.types import AuthConfig, SourceConnectionData, SourceRegistryEntry @@ -128,13 +129,19 @@ async def create( ) # 2. Get auth configuration (credentials + proxy setup) - auth_config = await self._get_auth_configuration( - db=db, - source_connection_data=source_connection_data, - ctx=ctx, - logger=logger, - access_token=access_token, - ) + try: + auth_config = await self._get_auth_configuration( + db=db, + source_connection_data=source_connection_data, + ctx=cast(ApiContext, ctx), + logger=logger, + access_token=access_token, + ) + except AuthProviderError as exc: + raise SourceValidationError( + short_name=source_connection_data.short_name, + reason=f"auth provider error: {exc}", + ) from exc # 3. Resolve auth provider token_provider = await self._resolve_token_provider( @@ -328,15 +335,12 @@ async def _get_auth_configuration( ) # Case 2: Auth provider connection - if ( - source_connection_data.readable_auth_provider_id - and source_connection_data.auth_provider_config - ): + if source_connection_data.readable_auth_provider_id: return await self._get_auth_provider_configuration( db=db, source_connection_data=source_connection_data, readable_auth_provider_id=source_connection_data.readable_auth_provider_id, - auth_provider_config=source_connection_data.auth_provider_config, + auth_provider_config=source_connection_data.auth_provider_config or {}, ctx=ctx, logger=logger, ) @@ -369,11 +373,13 @@ async def _get_auth_provider_configuration( logger=logger, ) - # Get runtime auth fields from the source registry (precomputed at startup) + # Get runtime auth fields from the source registry (precomputed at startup). + # Auth providers handle token refresh, so OAuth lifecycle fields + # (refresh_token, client_id, client_secret) are always optional. short_name = source_connection_data.short_name entry = self._source_registry.get(short_name) auth_fields_all = entry.runtime_auth_all_fields - auth_fields_optional = entry.runtime_auth_optional_fields + auth_fields_optional = entry.runtime_auth_optional_fields | AUTH_PROVIDER_OPTIONAL_FIELDS source_config_field_mappings = self._build_source_config_field_mappings( source_connection_data @@ -384,6 +390,7 @@ async def _get_auth_provider_configuration( source_auth_config_fields=auth_fields_all, optional_fields=auth_fields_optional, source_config_field_mappings=source_config_field_mappings or None, + source_connection_id=source_connection_data.source_connection_id, ) if auth_result.source_config: @@ -630,6 +637,17 @@ async def _resolve_token_provider( if access_token is not None: return StaticTokenProvider(access_token, source_short_name=short_name) + # Auth provider takes priority — ensures errors are classified correctly + # regardless of whether the source uses OAuth or direct auth. + if auth_provider_instance: + return AuthProviderTokenProvider( + auth_provider_instance=auth_provider_instance, + source_short_name=short_name, + source_registry=self._source_registry, + logger=logger, + source_connection_id=source_connection_data.source_connection_id, + ) + entry = self._source_registry.get(short_name) source_credentials = self._normalize_credentials(source_credentials, entry, logger) @@ -641,14 +659,6 @@ async def _resolve_token_provider( return DirectCredentialProvider(source_credentials, source_short_name=short_name) try: - if auth_provider_instance: - return AuthProviderTokenProvider( - auth_provider_instance=auth_provider_instance, - source_short_name=short_name, - source_registry=self._source_registry, - logger=logger, - ) - # Sources that support both OAuth and API key auth (e.g. calcom, coda) # may have structured credentials without access_token when using # API key mode — route those to DirectCredentialProvider. diff --git a/backend/airweave/domains/sources/tests/test_lifecycle.py b/backend/airweave/domains/sources/tests/test_lifecycle.py index 135ecb99d..c698e1b7b 100644 --- a/backend/airweave/domains/sources/tests/test_lifecycle.py +++ b/backend/airweave/domains/sources/tests/test_lifecycle.py @@ -402,7 +402,7 @@ class AuthConfigRoutingCase: AuthConfigRoutingCase(id="database-fallthrough", expected_route="database"), AuthConfigRoutingCase(id="auth-provider-id-but-no-config", readable_auth_provider_id="pd-1", - expected_route="database"), + expected_route="auth_provider"), ] diff --git a/backend/airweave/domains/sources/token_providers/auth_provider.py b/backend/airweave/domains/sources/token_providers/auth_provider.py index d16e7d298..9b833bc48 100644 --- a/backend/airweave/domains/sources/token_providers/auth_provider.py +++ b/backend/airweave/domains/sources/token_providers/auth_provider.py @@ -4,11 +4,12 @@ import time from typing import TYPE_CHECKING, Optional +from uuid import UUID from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from airweave.core.logging import ContextualLogger -from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider._base import AUTH_PROVIDER_OPTIONAL_FIELDS, BaseAuthProvider from airweave.domains.auth_provider.exceptions import ( AuthProviderAccountNotFoundError, AuthProviderAuthError, @@ -50,12 +51,14 @@ def __init__( source_registry: SourceRegistryProtocol, *, logger: ContextualLogger, + source_connection_id: Optional[UUID] = None, ): """Initialize with an auth provider instance and source registry.""" self._provider = auth_provider_instance self._source_short_name = source_short_name self._source_registry = source_registry self._logger = logger + self._source_connection_id = source_connection_id self._cached_token: Optional[str] = None self._cached_at: float = 0.0 @@ -134,15 +137,28 @@ async def _fetch_token(self) -> str: provider_kind=self.provider_kind, ) from e - if not isinstance(creds, dict) or "access_token" not in creds: + if not isinstance(creds, dict) or not creds: raise TokenProviderMissingCredsError( - f"No access_token in auth provider response for {self._source_short_name}", + f"Empty auth provider response for {self._source_short_name}", source_short_name=self._source_short_name, provider_kind=self.provider_kind, - missing_fields=["access_token"], + missing_fields=entry.runtime_auth_all_fields, ) - return creds["access_token"] + # Extract the primary credential value. Prefer access_token if present, + # otherwise use the first runtime auth field (e.g. personal_access_token, + # api_key). This supports both OAuth and non-OAuth sources. + for field in ["access_token"] + entry.runtime_auth_all_fields: + if field in creds: + return str(creds[field]) + + raise TokenProviderMissingCredsError( + f"No usable credential in auth provider response for {self._source_short_name}. " + f"Expected one of: {entry.runtime_auth_all_fields}", + source_short_name=self._source_short_name, + provider_kind=self.provider_kind, + missing_fields=entry.runtime_auth_all_fields, + ) @retry( retry=retry_if_exception_type((AuthProviderRateLimitError, AuthProviderServerError)), @@ -154,7 +170,8 @@ async def _call_provider_with_retry(self, entry) -> dict: return await self._provider.get_creds_for_source( source_short_name=self._source_short_name, source_auth_config_fields=entry.runtime_auth_all_fields, - optional_fields=entry.runtime_auth_optional_fields, + optional_fields=entry.runtime_auth_optional_fields | AUTH_PROVIDER_OPTIONAL_FIELDS, + source_connection_id=self._source_connection_id, ) async def get_token(self) -> str: diff --git a/backend/airweave/platform/configs/auth.py b/backend/airweave/platform/configs/auth.py index c6b62cd8b..6669ea1ee 100644 --- a/backend/airweave/platform/configs/auth.py +++ b/backend/airweave/platform/configs/auth.py @@ -1020,6 +1020,62 @@ class PipedreamAuthConfig(AuthConfig): ) +class CustomAuthConfig(AuthConfig): + """Custom Auth Provider authentication credentials schema. + + Stores a base URL and API key for a customer-hosted token endpoint. + Airweave calls GET {base_url}/{source_connection_id} to fetch credentials + for each source connection. The customer is responsible for returning fresh tokens. + + Contract: + - GET {base_url} must return 2xx (used for validation) + - GET {base_url}/{source_connection_id} must return JSON with credentials + - source_connection_id is the UUID of the source connection in Airweave + - All requests include an X-API-Key header for authentication + - No refresh_token needed — Airweave re-fetches automatically + + Response format (only two shapes): + - Token-based sources (OAuth, PATs): {"access_token": "..."} + - API key sources (Stripe, etc.): {"api_key": "..."} + """ + + base_endpoint_url: str = Field( + title="Endpoint Base URL", + description=( + "Base URL of your token endpoint. " + "Airweave calls GET {base_url}/{source_connection_id} for each " + "source connection and expects a JSON response with the credential " + 'the source needs (e.g. {"access_token": "..."} for OAuth sources). ' + "No refresh_token needed — Airweave re-fetches automatically. " + "GET {base_url} must return 2xx for validation." + ), + ) + api_key: str = Field( + title="API Key", + description=( + "API key sent as X-API-Key header to authenticate all requests to your endpoint." + ), + ) + + @field_validator("base_endpoint_url") + @classmethod + def validate_base_endpoint_url(cls, v: str) -> str: + """Validate the endpoint URL for SSRF safety.""" + if not v or not v.strip(): + raise ValueError("base_endpoint_url is required") + v = v.strip().rstrip("/") + validate_url(v) + return v + + @field_validator("api_key") + @classmethod + def validate_api_key(cls, v: str) -> str: + """Validate that api_key is non-empty.""" + if not v or not v.strip(): + raise ValueError("api_key is required") + return v.strip() + + class ZohoCRMAuthConfig(OAuth2WithRefreshAuthConfig): """Zoho CRM authentication credentials schema.""" diff --git a/backend/airweave/platform/configs/config.py b/backend/airweave/platform/configs/config.py index c522e4d5e..db091cf05 100644 --- a/backend/airweave/platform/configs/config.py +++ b/backend/airweave/platform/configs/config.py @@ -1130,6 +1130,12 @@ class ComposioConfig(AuthProviderConfig): ) +class CustomConfig(AuthProviderConfig): + """Custom Auth Provider configuration schema.""" + + pass + + class PipedreamConfig(AuthProviderConfig): """Pipedream Auth Provider configuration schema.""" diff --git a/backend/airweave/platform/decorators.py b/backend/airweave/platform/decorators.py index 7c2b4b9dd..25fbdec1b 100644 --- a/backend/airweave/platform/decorators.py +++ b/backend/airweave/platform/decorators.py @@ -198,6 +198,7 @@ def auth_provider( short_name: str, auth_config_class: Type[BaseModel], config_class: Type[BaseModel], + feature_flag: Optional[str] = None, ) -> Callable[[type[_AuthProviderT]], type[_AuthProviderT]]: """Class decorator to mark a class as representing an Airweave auth provider. @@ -207,6 +208,7 @@ def auth_provider( short_name (str): The short name of the auth provider. auth_config_class (Type[BaseModel]): The authentication config class. config_class (Type[BaseModel]): The configuration class. + feature_flag (Optional[str]): Optional feature flag required to access this provider. Returns: ------- @@ -220,6 +222,7 @@ def decorator(cls: type[_AuthProviderT]) -> type[_AuthProviderT]: cls.short_name = short_name cls.auth_config_class = auth_config_class cls.config_class = config_class + cls.feature_flag = feature_flag return cls return decorator diff --git a/backend/tests/e2e/smoke/test_sources.py b/backend/tests/e2e/smoke/test_sources.py index da92dd856..62ca82e9d 100644 --- a/backend/tests/e2e/smoke/test_sources.py +++ b/backend/tests/e2e/smoke/test_sources.py @@ -154,7 +154,7 @@ async def test_supported_auth_providers_structure(self, api_client: httpx.AsyncC sources = response.json() # Known valid auth provider short names - valid_providers = ["pipedream", "composio"] + valid_providers = ["pipedream", "composio", "custom"] for source in sources: providers = source.get("supported_auth_providers", []) diff --git a/backend/tests/unit/platform/sync/test_token_providers.py b/backend/tests/unit/platform/sync/test_token_providers.py index 85803dee4..fc194117a 100644 --- a/backend/tests/unit/platform/sync/test_token_providers.py +++ b/backend/tests/unit/platform/sync/test_token_providers.py @@ -221,7 +221,7 @@ def _make_provider(self, access_token: str = "fresh_token"): mock_registry = MagicMock() entry = MagicMock() entry.runtime_auth_all_fields = ["access_token"] - entry.runtime_auth_optional_fields = [] + entry.runtime_auth_optional_fields = set() mock_registry.get.return_value = entry return AuthProviderTokenProvider( diff --git a/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx b/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx new file mode 100644 index 000000000..f71a7a1c9 --- /dev/null +++ b/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx @@ -0,0 +1,92 @@ +import React from "react"; +import { useTheme } from "@/lib/theme-provider"; +import { cn } from "@/lib/utils"; +import { Plus } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { getAuthProviderIconUrl } from "@/lib/utils/icons"; +import { AuthProvider, AuthProviderConnection, useAuthProvidersStore } from "@/lib/stores/authProviders"; +import { format } from "date-fns"; + +interface AuthProviderConnectionsListProps { + authProvider: AuthProvider; + onSelectConnection: (connection: AuthProviderConnection) => void; + onAddNew: () => void; + onCancel: () => void; +} + +export const AuthProviderConnectionsList: React.FC = ({ + authProvider, + onSelectConnection, + onAddNew, + onCancel, +}) => { + const { resolvedTheme } = useTheme(); + const isDark = resolvedTheme === "dark"; + const { authProviderConnections } = useAuthProvidersStore(); + + const connections = authProviderConnections.filter( + (conn) => conn.short_name === authProvider?.short_name + ); + + return ( +
+ {/* Header */} +
+
+ {`${authProvider?.name} +
+
+

{authProvider?.name}

+

+ {connections.length} connection{connections.length !== 1 ? "s" : ""} +

+
+
+ + {/* Connections list */} +
+ {connections.map((conn) => ( + + ))} +
+ + {/* Actions */} +
+ + +
+
+ ); +}; diff --git a/frontend/src/components/auth-providers/AuthProviderDetailView.tsx b/frontend/src/components/auth-providers/AuthProviderDetailView.tsx index 3e23df8b2..6090866ae 100644 --- a/frontend/src/components/auth-providers/AuthProviderDetailView.tsx +++ b/frontend/src/components/auth-providers/AuthProviderDetailView.tsx @@ -226,20 +226,8 @@ export const AuthProviderDetailView: React.FC = ({ const [isDeleting, setIsDeleting] = useState(false); const [isClosing, setIsClosing] = useState(false); - console.log('🔍 [AuthProviderDetailView] Component mounted with:', { - authProviderConnectionId, - authProviderName, - authProviderShortName, - viewData - }); - - // Log component lifecycle useEffect(() => { - console.log('🌟 [AuthProviderDetailView] useEffect mount check'); - return () => { - console.log('💥 [AuthProviderDetailView] Component unmounting'); - // Clear any stored errors when component unmounts clearStoredErrorDetails(); }; }, []); @@ -247,16 +235,12 @@ export const AuthProviderDetailView: React.FC = ({ // Fetch connection details useEffect(() => { if (!authProviderConnectionId || isClosing) { - if (!authProviderConnectionId) { - console.warn('⚠️ [AuthProviderDetailView] No authProviderConnectionId provided'); - } return; } let isMounted = true; const fetchConnectionDetails = async () => { - console.log('📡 [AuthProviderDetailView] Fetching connection details for:', authProviderConnectionId); setLoading(true); try { const response = await apiClient.get(`/auth-providers/connections/${authProviderConnectionId}`); diff --git a/frontend/src/components/auth-providers/AuthProviderDialog.tsx b/frontend/src/components/auth-providers/AuthProviderDialog.tsx index a8b4db3df..cf2263340 100644 --- a/frontend/src/components/auth-providers/AuthProviderDialog.tsx +++ b/frontend/src/components/auth-providers/AuthProviderDialog.tsx @@ -2,6 +2,7 @@ import React, { useState, useEffect } from "react"; import { Dialog, DialogContent } from "@/components/ui/dialog"; import { ConfigureAuthProviderView } from "./ConfigureAuthProviderView"; import { AuthProviderDetailView } from "./AuthProviderDetailView"; +import { AuthProviderConnectionsList } from "./AuthProviderConnectionsList"; import { EditAuthProviderView } from "@/components/shared/views/EditAuthProviderView"; import { useTheme } from "@/lib/theme-provider"; import { cn } from "@/lib/utils"; @@ -14,7 +15,7 @@ export type { DialogViewProps }; interface AuthProviderDialogProps { open: boolean; onOpenChange: (open: boolean) => void; - mode: 'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit'; + mode: 'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit' | 'auth-provider-list'; authProvider: any; connection?: any; onComplete?: (result: any) => void; @@ -78,8 +79,6 @@ export const AuthProviderDialog: React.FC = ({ }; const handleNext = (data?: any) => { - console.log("🚀 [AuthProviderDialog] handleNext called with:", data); - // Merge new data with existing viewData const newViewData = { ...viewData, ...data }; setViewData(newViewData); @@ -91,8 +90,6 @@ export const AuthProviderDialog: React.FC = ({ }; const handleComplete = (result?: any) => { - console.log("✅ [AuthProviderDialog] handleComplete called with:", result); - // Handle different completion actions if (result?.action === 'edit') { // Switch to edit mode @@ -164,6 +161,25 @@ export const AuthProviderDialog: React.FC = ({ /> ); + case 'auth-provider-list': + return ( + { + setViewData(prev => ({ + ...prev, + authProviderConnectionId: conn.readable_id, + authProviderConnectionName: conn.name, + })); + setCurrentView('auth-provider-detail'); + }} + onAddNew={() => { + setCurrentView('auth-provider'); + }} + onCancel={handleCancel} + /> + ); + default: return null; } diff --git a/frontend/src/components/auth-providers/AuthProviderTable.tsx b/frontend/src/components/auth-providers/AuthProviderTable.tsx index 611706a4f..d38b9967f 100644 --- a/frontend/src/components/auth-providers/AuthProviderTable.tsx +++ b/frontend/src/components/auth-providers/AuthProviderTable.tsx @@ -24,7 +24,7 @@ export const AuthProviderTable = () => { const [dialogOpen, setDialogOpen] = useState(false); const [selectedAuthProvider, setSelectedAuthProvider] = useState(null); const [selectedConnection, setSelectedConnection] = useState(null); - const [dialogMode, setDialogMode] = useState<'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit'>('auth-provider'); + const [dialogMode, setDialogMode] = useState<'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit' | 'auth-provider-list'>('auth-provider'); const [remountKey, setRemountKey] = useState(0); // Fetch auth providers and connections on component mount @@ -33,79 +33,35 @@ export const AuthProviderTable = () => { Promise.all([ fetchAuthProviders(), fetchAuthProviderConnections() - ]).then(([providers, connections]) => { - console.log(`🔄 [AuthProviderTable] Auth providers loaded: ${providers.length} providers, ${connections.length} connections`); - }); + ]); }, [fetchAuthProviders, fetchAuthProviderConnections]); - // Log state changes - useEffect(() => { - console.log('🎮 [AuthProviderTable] State changed:', { - dialogOpen, - dialogMode, - selectedAuthProvider: selectedAuthProvider?.short_name, - selectedConnection: selectedConnection?.readable_id - }); - }, [dialogOpen, dialogMode, selectedAuthProvider, selectedConnection]); - - // Log when connections change - useEffect(() => { - console.log('📊 [AuthProviderTable] Auth provider connections changed:', { - count: authProviderConnections.length, - isLoadingConnections - }); - }, [authProviderConnections, isLoadingConnections]); - - // Log when dialog open state changes - useEffect(() => { - console.log('🚨 [AuthProviderTable] dialogOpen state changed to:', dialogOpen); - }, [dialogOpen]); + const handleAuthProviderClick = (authProvider: any) => { + const connections = authProviderConnections.filter(conn => conn.short_name === authProvider.short_name); - // Log dialog state for debugging + setSelectedAuthProvider(authProvider); - const handleAuthProviderClick = (authProvider: any) => { - console.log('🖱️ [AuthProviderTable] handleAuthProviderClick called:', { - authProvider: authProvider.short_name, - hasConnection: !!authProviderConnections.find(conn => conn.short_name === authProvider.short_name) - }); - - const connection = authProviderConnections.find(conn => conn.short_name === authProvider.short_name); - - if (connection) { - console.log('🔗 [AuthProviderTable] Found existing connection:', connection.readable_id); - // Auth provider is already connected, show details - setSelectedAuthProvider(authProvider); - setSelectedConnection(connection); - setDialogMode('auth-provider-detail'); + if (connections.length === 0) { + if (!canManage) { + toast.info("Only admins can configure auth providers"); + return; + } + setSelectedConnection(null); + setDialogMode('auth-provider'); setDialogOpen(true); - } else if (!canManage) { - toast.info("Only admins can configure auth providers"); - return; } else { - console.log('➕ [AuthProviderTable] No connection found, opening configure dialog'); - // Auth provider not connected, show configure dialog - setSelectedAuthProvider(authProvider); setSelectedConnection(null); - setDialogMode('auth-provider'); + setDialogMode('auth-provider-list'); setDialogOpen(true); } - - console.log('🎯 [AuthProviderTable] Dialog state after click:', { - dialogOpen: true, - dialogMode: connection ? 'auth-provider-detail' : 'auth-provider', - selectedAuthProvider: authProvider.short_name - }); }; const handleDialogComplete = (result: any) => { - console.log("🏁 [AuthProviderTable] Dialog completed:", result); - // Close the dialog setDialogOpen(false); // If it was an edit action, open edit dialog if (result?.action === 'edit') { - console.log("✏️ [AuthProviderTable] Edit action requested, opening edit dialog"); // Store the auth provider details for edit dialog const tempAuthProvider = selectedAuthProvider; @@ -130,7 +86,6 @@ export const AuthProviderTable = () => { // If it was an updated action, open detail dialog with refreshed data if (result?.action === 'updated') { - console.log("✅ [AuthProviderTable] Auth provider connection was updated"); // Find the updated connection from the refreshed list const updatedConnection = authProviderConnections.find( @@ -153,7 +108,6 @@ export const AuthProviderTable = () => { // If it was a deletion, increment remountKey to force dialog remount if (result?.action === 'deleted') { - console.log("🗑️ [AuthProviderTable] Auth provider connection was deleted"); setRemountKey(prev => prev + 1); } @@ -164,7 +118,6 @@ export const AuthProviderTable = () => { // Refresh connections if a new one was created or deleted if (result?.success) { - console.log("♻️ [AuthProviderTable] Refreshing auth provider connections"); fetchAuthProviderConnections(); } }; @@ -187,9 +140,7 @@ export const AuthProviderTable = () => { // Memoize dialog key to prevent remounts const dialogKey = useMemo(() => { // Only use auth provider short name as key since connection ID isn't available when creating new - const key = dialogOpen ? `auth-${selectedAuthProvider?.short_name || 'none'}-${remountKey}` : 'closed'; - console.log('🔑 [AuthProviderTable] Dialog key:', key); - return key; + return dialogOpen ? `auth-${selectedAuthProvider?.short_name || 'none'}-${remountKey}` : 'closed'; }, [dialogOpen, selectedAuthProvider?.short_name, remountKey]); return ( @@ -205,7 +156,7 @@ export const AuthProviderTable = () => { ) : ( allProviders.map(provider => { - const connection = authProviderConnections.find( + const connections = authProviderConnections.filter( conn => conn.short_name === provider.short_name ); @@ -215,7 +166,8 @@ export const AuthProviderTable = () => { id={provider.short_name} name={provider.name} shortName={provider.short_name} - isConnected={!!connection} + isConnected={connections.length > 0} + connectionCount={connections.length} isComingSoon={'isComingSoon' in provider ? provider.isComingSoon : false} onClick={() => handleAuthProviderClick(provider)} /> diff --git a/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx b/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx index be93ebb38..3c77b3541 100644 --- a/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx +++ b/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx @@ -81,19 +81,6 @@ export const ConfigureAuthProviderView: React.FC const navigate = useNavigate(); const { fetchAuthProviderConnections } = useAuthProvidersStore(); - // Log component lifecycle - useEffect(() => { - console.log('🌟 [ConfigureAuthProviderView] Component mounted:', { - authProviderName, - authProviderShortName, - viewData - }); - - return () => { - console.log('💥 [ConfigureAuthProviderView] Component unmounting'); - }; - }, []); - const [isSubmitting, setIsSubmitting] = useState(false); const [loading, setLoading] = useState(true); const [authProviderDetails, setAuthProviderDetails] = useState(null); @@ -105,11 +92,6 @@ export const ConfigureAuthProviderView: React.FC const [airweaveImageError, setAirweaveImageError] = useState(false); const [authProviderImageError, setAuthProviderImageError] = useState(false); - // Log loading state changes - useEffect(() => { - console.log('⏳ [ConfigureAuthProviderView] Loading state:', loading); - }, [loading]); - // Default name for the connection const defaultConnectionName = authProviderName ? `My ${authProviderName} Connection` : "My Connection"; @@ -157,30 +139,18 @@ export const ConfigureAuthProviderView: React.FC // Fetch auth provider details useEffect(() => { - console.log('🔍 [ConfigureAuthProviderView] Auth provider details effect triggered:', { - authProviderShortName, - currentLoading: loading - }); - if (!authProviderShortName) { - console.log('⚠️ [ConfigureAuthProviderView] No authProviderShortName, skipping fetch'); setLoading(false); return; } const fetchDetails = async () => { - console.log('🚀 [ConfigureAuthProviderView] Starting to fetch auth provider details'); setLoading(true); try { const response = await apiClient.get(`/auth-providers/detail/${authProviderShortName}`); - console.log('📡 [ConfigureAuthProviderView] Auth provider details response:', response.ok); if (response.ok) { const data = await response.json(); - console.log('✅ [ConfigureAuthProviderView] Auth provider details loaded:', { - hasAuthFields: !!data.auth_fields, - fieldsCount: data.auth_fields?.fields?.length || 0 - }); setAuthProviderDetails(data); // Initialize auth field values @@ -195,16 +165,13 @@ export const ConfigureAuthProviderView: React.FC } } else { const errorText = await response.text(); - console.error('❌ [ConfigureAuthProviderView] Failed to load auth provider details:', errorText); throw new Error(`Failed to load auth provider details: ${errorText}`); } } catch (error) { - console.error("Error fetching auth provider details:", error); if (onError) { onError(error instanceof Error ? error : new Error(String(error)), authProviderName); } } finally { - console.log('🏁 [ConfigureAuthProviderView] Setting loading to false'); setLoading(false); } }; @@ -391,16 +358,7 @@ export const ConfigureAuthProviderView: React.FC duration: 5000, }); - // Navigate to detail view BEFORE refreshing connections - console.log('🎯 [ConfigureAuthProviderView] Connection created successfully:', { - connectionId: connection.id, - readableId: connection.readable_id, - name: connection.name, - shortName: connection.short_name - }); - if (onNext) { - console.log('🚀 [ConfigureAuthProviderView] Calling onNext to navigate to detail view'); onNext({ authProviderConnectionId: connection.readable_id, authProviderName: authProviderName, // Use the original auth provider name, not connection name @@ -408,17 +366,11 @@ export const ConfigureAuthProviderView: React.FC isNewConnection: true // Flag to indicate this is a new connection }); - // Refresh connections after navigation - testing without delay - console.log('📡 [ConfigureAuthProviderView] Refreshing auth provider connections after navigation'); fetchAuthProviderConnections(); } else { - console.warn('⚠️ [ConfigureAuthProviderView] onNext is not defined!'); - // If no onNext, refresh immediately await fetchAuthProviderConnections(); } } catch (error) { - console.error("Error creating auth provider connection:", error); - // Extract error message from the response let errorMessage = "Failed to create connection"; if (error instanceof Error) { @@ -654,6 +606,37 @@ export const ConfigureAuthProviderView: React.FC )} + {authProviderShortName === 'custom' && ( + + + + + + +
+

Your endpoint must implement:

+

GET {'{base_url}'} — return 2xx (used for validation)

+

GET {'{base_url}/{source_connection_id}'} — return JSON credentials

+

Response: {`{"access_token": "..."}`} or {`{"api_key": "..."}`}

+

Auth: X-API-Key header sent with every request

+

No refresh_token needed — Airweave re-fetches automatically

+
+
+
+
+ )}
{authProviderDetails.auth_fields.fields.map((field: any) => ( diff --git a/frontend/src/components/dashboard/AuthProviderButton.tsx b/frontend/src/components/dashboard/AuthProviderButton.tsx index a3434f2ec..5a006c54f 100644 --- a/frontend/src/components/dashboard/AuthProviderButton.tsx +++ b/frontend/src/components/dashboard/AuthProviderButton.tsx @@ -10,7 +10,8 @@ interface AuthProviderButtonProps { name: string; shortName: string; isConnected?: boolean; - isComingSoon?: boolean; // Add this prop + connectionCount?: number; + isComingSoon?: boolean; onClick?: () => void; } @@ -19,7 +20,8 @@ export const AuthProviderButton = ({ name, shortName, isConnected = false, - isComingSoon = false, // Add default value + connectionCount = 0, + isComingSoon = false, onClick }: AuthProviderButtonProps) => { const { resolvedTheme } = useTheme(); @@ -30,12 +32,6 @@ export const AuthProviderButton = ({ // Don't handle clicks for coming soon providers if (isComingSoon) return; - console.log('🔘 [AuthProviderButton] Button clicked:', { - id, - name, - shortName, - isConnected - }); if (onClick) { onClick(); } @@ -78,7 +74,19 @@ export const AuthProviderButton = ({ )}
- {name} +
+ {name} + {connectionCount > 1 && ( + + {connectionCount} + + )} +
{isComingSoon && ( Coming soon diff --git a/frontend/src/components/icons/auth_providers/custom-dark.svg b/frontend/src/components/icons/auth_providers/custom-dark.svg new file mode 100644 index 000000000..ecef112bf --- /dev/null +++ b/frontend/src/components/icons/auth_providers/custom-dark.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/frontend/src/components/icons/auth_providers/custom-light.svg b/frontend/src/components/icons/auth_providers/custom-light.svg new file mode 100644 index 000000000..822c72096 --- /dev/null +++ b/frontend/src/components/icons/auth_providers/custom-light.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/frontend/src/lib/constants/feature-flags.ts b/frontend/src/lib/constants/feature-flags.ts index 223868021..a1694c019 100644 --- a/frontend/src/lib/constants/feature-flags.ts +++ b/frontend/src/lib/constants/feature-flags.ts @@ -18,6 +18,9 @@ export const FeatureFlags = { // Connect CONNECT: 'connect', + + // Auth Providers + CUSTOM_AUTH_PROVIDER: 'custom_auth_provider', } as const; export type FeatureFlag = typeof FeatureFlags[keyof typeof FeatureFlags]; diff --git a/frontend/src/lib/stores/authProviders.ts b/frontend/src/lib/stores/authProviders.ts index d958e1549..ec8a57677 100644 --- a/frontend/src/lib/stores/authProviders.ts +++ b/frontend/src/lib/stores/authProviders.ts @@ -116,7 +116,6 @@ export const useAuthProvidersStore = create((set, get) => ({ }, clearAuthProviderConnections: () => { - console.log("🧹 [AuthProvidersStore] Clearing auth provider connections"); set({ authProviderConnections: [] }); } })); From d62597079d56907a5448445b6b877b8e1d2538e8 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Mon, 20 Apr 2026 12:07:28 +0200 Subject: [PATCH 06/25] fix(ci): use in-project venvs for mypy and import-linter jobs mypy and import-linter used virtualenvs-create: false, so poetry run couldn't find console scripts like diff-quality. Switch all three jobs to the same pattern: in-project venv with caching. Broken since Mar 16 when diff-quality was added to the mypy job without updating its venv config. --- .github/workflows/code-quality.yml | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 3d18d8f3d..fb3f2cb64 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -83,9 +83,17 @@ jobs: uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1.4.1 with: version: 2.3.2 - virtualenvs-create: false + virtualenvs-create: true + virtualenvs-in-project: true + - name: Load cached venv + id: cached-mypy-deps + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ./backend/.venv + key: mypy-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} - name: Install dependencies - run: poetry install + if: steps.cached-mypy-deps.outputs.cache-hit != 'true' + run: poetry install --no-interaction - name: Run mypy if: github.event_name == 'push' run: poetry run mypy --config-file pyproject.toml airweave/ @@ -116,8 +124,16 @@ jobs: uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1.4.1 with: version: 2.3.2 - virtualenvs-create: false + virtualenvs-create: true + virtualenvs-in-project: true + - name: Load cached venv + id: cached-importlint-deps + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ./backend/.venv + key: lint-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} - name: Install lint dependencies + if: steps.cached-importlint-deps.outputs.cache-hit != 'true' run: poetry install --only lint --no-interaction --no-root - name: Run import-linter run: poetry run lint-imports From 87ce505bc986ab66d47039bcb0f1b99e3e4c7794 Mon Sep 17 00:00:00 2001 From: Marc Rutzou Date: Mon, 20 Apr 2026 12:49:12 +0200 Subject: [PATCH 07/25] feat: add Mistral LLM adapter (#1754) * feat: add Mistral LLM adapter with Large and Small model support Adds MistralLLM provider using the native mistralai SDK with json_schema structured output and OpenAI-compatible tool/function calling. * fix: register MistralLLM in container factory provider_classes * feat: add thinking/reasoning support to Mistral adapter - Parse ThinkChunk content blocks from Magistral and Mistral Small 4 - Pass reasoning_effort param for adjustable reasoning models - Add MAGISTRAL_SMALL model, update MISTRAL_SMALL thinking config - Add 4 new tests for thinking extraction and reasoning_effort * fix: suppress thinking output when caller did not request it * refactor: address review feedback on Mistral adapter - Type `_extract_text` / `_extract_text_and_thinking` with `AssistantMessageContent | Unset | None` instead of `Any`. - Move the empty-content guard to run after text extraction so ThinkChunk-only responses (reasoning models) are caught by the same transient-retry path. - Drop the `str(raw_content)` fallback that could hand non-JSON to the parser. Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- backend/airweave/adapters/llm/__init__.py | 2 + backend/airweave/adapters/llm/mistral.py | 260 +++++++++++++ backend/airweave/adapters/llm/override.py | 2 + backend/airweave/adapters/llm/registry.py | 42 ++ .../adapters/llm/tests/test_mistral.py | 361 ++++++++++++++++++ backend/airweave/core/container/factory.py | 2 + 6 files changed, 669 insertions(+) create mode 100644 backend/airweave/adapters/llm/mistral.py create mode 100644 backend/airweave/adapters/llm/tests/test_mistral.py diff --git a/backend/airweave/adapters/llm/__init__.py b/backend/airweave/adapters/llm/__init__.py index a00b5dfd2..9733d29d0 100644 --- a/backend/airweave/adapters/llm/__init__.py +++ b/backend/airweave/adapters/llm/__init__.py @@ -12,6 +12,7 @@ ) from airweave.adapters.llm.fallback import FallbackChainLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.override import create_llm_from_override from airweave.adapters.llm.registry import ( MODEL_REGISTRY, @@ -35,6 +36,7 @@ "AnthropicLLM", "CerebrasLLM", "GroqLLM", + "MistralLLM", "TogetherLLM", "FallbackChainLLM", # Exceptions diff --git a/backend/airweave/adapters/llm/mistral.py b/backend/airweave/adapters/llm/mistral.py new file mode 100644 index 000000000..68e6717a3 --- /dev/null +++ b/backend/airweave/adapters/llm/mistral.py @@ -0,0 +1,260 @@ +"""Mistral LLM implementation. + +Uses the native mistralai SDK for chat completions with json_schema +structured output and OpenAI-compatible tool/function calling. + +Supports reasoning/thinking via two mechanisms: +- Magistral models: native thinking (always-on), returned as ThinkChunk content blocks +- Mistral Small 4: adjustable reasoning via reasoning_effort parameter +""" + +import json +import time +from typing import Any, TypeVar + +from mistralai import Mistral +from mistralai.models import AssistantMessageContent +from mistralai.models.jsonschema import JSONSchema +from mistralai.models.responseformat import ResponseFormat +from mistralai.models.textchunk import TextChunk +from mistralai.models.thinkchunk import ThinkChunk +from mistralai.types.basemodel import Unset +from pydantic import BaseModel + +from airweave.adapters.llm.base import BaseLLM +from airweave.adapters.llm.exceptions import LLMTransientError +from airweave.adapters.llm.registry import LLMModelSpec +from airweave.adapters.llm.tool_response import LLMResponse, LLMToolCall +from airweave.core.config import settings + +T = TypeVar("T", bound=BaseModel) + + +class MistralLLM(BaseLLM): + """Mistral LLM provider with json_schema structured output and tool calling.""" + + def __init__( + self, + model_spec: LLMModelSpec, + max_retries: int | None = None, + ) -> None: + """Initialize the Mistral LLM client with API key validation.""" + super().__init__(model_spec, max_retries=max_retries) + + api_key = settings.MISTRAL_API_KEY + if not api_key: + raise ValueError( + "MISTRAL_API_KEY not configured. Set it in your environment or .env file." + ) + + try: + self._client = Mistral(api_key=api_key) + except Exception as e: + raise RuntimeError(f"Failed to initialize Mistral client: {e}") from e + + self._logger.debug( + f"[MistralLLM] Initialized with model={model_spec.api_model_name}, " + f"context_window={model_spec.context_window}, " + f"max_output_tokens={model_spec.max_output_tokens}" + ) + + def _prepare_schema(self, schema_json: dict[str, Any]) -> dict[str, Any]: + return self._normalize_strict_schema(schema_json) + + async def _call_api( + self, + prompt: str, + schema: type[T], + schema_json: dict[str, Any], + system_prompt: str, + thinking: bool = False, + ) -> T: + api_start = time.monotonic() + response = await self._client.chat.complete_async( + model=self._model_spec.api_model_name, + messages=[ # type: ignore[arg-type] + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + response_format=ResponseFormat( # type: ignore[arg-type] + type="json_schema", + json_schema=JSONSchema( + name=schema.__name__.lower(), + strict=True, + schema_definition=schema_json, + ), + ), + max_tokens=self._model_spec.max_output_tokens, + ) + api_time = time.monotonic() - api_start + + # Empty body (including ThinkChunk-only content for reasoning models) + # is transient: retry often clears a momentary truncation on the API side. + content = _extract_text(response.choices[0].message.content) + if not content: + raise LLMTransientError( + "Mistral returned empty response content", + provider=self._name, + ) + + if response.usage: + self._logger.debug( + f"[MistralLLM] API call completed in {api_time:.2f}s, " + f"tokens: prompt={response.usage.prompt_tokens}, " + f"completion={response.usage.completion_tokens}, " + f"total={response.usage.total_tokens}" + ) + + return self._parse_json_response(content, schema) + + async def _call_api_chat( + self, + messages: list[dict], + tools: list[dict], + system_prompt: str, + thinking: bool = False, + max_tokens: int | None = None, + ) -> LLMResponse: + """Mistral tool calling with OpenAI-compatible format.""" + converted = self._prepare_messages_for_api(messages) + api_messages = [{"role": "system", "content": system_prompt}, *converted] + + # Mistral uses OpenAI-compatible tool definitions directly + strict_tools = self._prepare_tools_strict(tools) + + # Build reasoning params based on thinking config + reasoning_params: dict[str, Any] = {} + tc = self._model_spec.thinking_config + if tc and tc.param_name == "reasoning_effort": + reasoning_params[tc.param_name] = "high" if thinking else "none" + + api_start = time.monotonic() + response = await self._client.chat.complete_async( + model=self._model_spec.api_model_name, + messages=api_messages, # type: ignore[arg-type] + tools=strict_tools, # type: ignore[arg-type] + tool_choice="any", + temperature=0.3, + max_tokens=max_tokens or self._model_spec.max_output_tokens, + **reasoning_params, + ) + api_time = time.monotonic() - api_start + + choice = response.choices[0] + message = choice.message + + # Parse content — may contain thinking chunks for reasoning models + raw_content = message.content + text, thinking_text = _extract_text_and_thinking(raw_content) + + # Only surface thinking when the caller requested it + if not thinking: + thinking_text = None + + tool_calls: list[LLMToolCall] = [] + if message.tool_calls: + for tc_item in message.tool_calls: + arguments = tc_item.function.arguments + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + tool_calls.append( + LLMToolCall( + id=tc_item.id or "", + name=tc_item.function.name, + arguments=arguments, + ) + ) + + prompt_tokens = 0 + completion_tokens = 0 + if response.usage: + prompt_tokens = response.usage.prompt_tokens or 0 + completion_tokens = response.usage.completion_tokens or 0 + self._logger.debug( + f"[MistralLLM] Tool call completed in {api_time:.2f}s, " + f"tokens: prompt={prompt_tokens}, completion={completion_tokens}" + ) + + return LLMResponse( + text=text, + thinking=thinking_text, + tool_calls=tool_calls, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + def _prepare_tools_strict(self, tools: list[dict]) -> list[dict]: + """Normalize tool parameter schemas for Mistral's json_schema strict mode.""" + strict_tools = [] + for tool in tools: + func = tool["function"] + params = self._normalize_strict_schema(func["parameters"]) + strict_tools.append( + { + "type": "function", + "function": { + "name": func["name"], + "description": func.get("description", ""), + "parameters": params, + }, + } + ) + return strict_tools + + async def close(self) -> None: + """Close the Mistral client and release resources.""" + if self._client: + # Mistral SDK uses context manager protocol (__aexit__) for cleanup + await self._client.__aexit__(None, None, None) + self._logger.debug("[MistralLLM] Client closed") + + +# ── Module-level helpers ─────────────────────────────────────────────── + + +def _extract_text(raw_content: AssistantMessageContent | Unset | None) -> str: + """Extract text from content, which may be a string or list of typed chunks.""" + if not isinstance(raw_content, (str, list)): + return "" + + if isinstance(raw_content, str): + return raw_content + + text_parts = [chunk.text for chunk in raw_content if isinstance(chunk, TextChunk)] + return "\n".join(text_parts) + + +def _extract_text_and_thinking( + raw_content: AssistantMessageContent | Unset | None, +) -> tuple[str | None, str | None]: + """Extract text and thinking from content chunks. + + Mistral reasoning models (Magistral, Mistral Small 4 with reasoning_effort) + return ThinkChunk blocks alongside TextChunk blocks in the content array. + ThinkChunk.thinking is a list of TextChunk/ReferenceChunk sub-items. + """ + if not isinstance(raw_content, (str, list)): + return None, None + + if isinstance(raw_content, str): + return raw_content or None, None + + text_parts: list[str] = [] + thinking_parts: list[str] = [] + + for chunk in raw_content: + if isinstance(chunk, ThinkChunk): + # ThinkChunk.thinking is List[Union[TextChunk, ReferenceChunk]] + for sub in chunk.thinking: + if isinstance(sub, TextChunk): + thinking_parts.append(sub.text) + elif isinstance(chunk, TextChunk): + text_parts.append(chunk.text) + + text = "\n".join(text_parts) if text_parts else None + thinking_text = "\n".join(thinking_parts) if thinking_parts else None + return text, thinking_text diff --git a/backend/airweave/adapters/llm/override.py b/backend/airweave/adapters/llm/override.py index af8a468b9..40295bae4 100644 --- a/backend/airweave/adapters/llm/override.py +++ b/backend/airweave/adapters/llm/override.py @@ -7,6 +7,7 @@ from airweave.adapters.llm.anthropic import AnthropicLLM from airweave.adapters.llm.cerebras import CerebrasLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.registry import ( LLMModel, LLMProvider, @@ -19,6 +20,7 @@ LLMProvider.ANTHROPIC: AnthropicLLM, LLMProvider.CEREBRAS: CerebrasLLM, LLMProvider.GROQ: GroqLLM, + LLMProvider.MISTRAL: MistralLLM, LLMProvider.TOGETHER: TogetherLLM, } diff --git a/backend/airweave/adapters/llm/registry.py b/backend/airweave/adapters/llm/registry.py index 01edfce60..0665b1ee1 100644 --- a/backend/airweave/adapters/llm/registry.py +++ b/backend/airweave/adapters/llm/registry.py @@ -23,6 +23,7 @@ class LLMProvider(str, Enum): GROQ = "groq" ANTHROPIC = "anthropic" TOGETHER = "together" + MISTRAL = "mistral" class LLMModel(str, Enum): @@ -45,6 +46,9 @@ class LLMModel(str, Enum): QWEN_3_5_DEDICATED = "qwen-3.5-dedicated" ZAI_GLM_5_DEDICATED = "zai-glm-5-dedicated" MINIMAX_M2_5 = "minimax-m2.5" + MISTRAL_LARGE = "mistral-large" + MISTRAL_SMALL = "mistral-small" + MAGISTRAL_SMALL = "magistral-small" @dataclass(frozen=True) @@ -205,6 +209,43 @@ class LLMModelSpec: output_price_factor=1.20, ), }, + LLMProvider.MISTRAL: { + LLMModel.MISTRAL_LARGE: LLMModelSpec( + api_model_name="mistral-large-latest", + context_window=256_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig(param_name="_noop", param_value=True), + input_price_factor=2.0, + output_price_factor=6.0, + ), + # Mistral Small 4 — adjustable reasoning via reasoning_effort + LLMModel.MISTRAL_SMALL: LLMModelSpec( + api_model_name="mistral-small-latest", + context_window=128_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig( + param_name="reasoning_effort", + param_value="high", + ), + input_price_factor=0.1, + output_price_factor=0.3, + ), + # Magistral Small — native reasoning (always-on thinking) + LLMModel.MAGISTRAL_SMALL: LLMModelSpec( + api_model_name="magistral-small-latest", + context_window=128_000, + max_output_tokens=40_000, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig(param_name="_noop", param_value=True), + input_price_factor=0.5, + output_price_factor=1.5, + ), + }, } @@ -214,6 +255,7 @@ class LLMModelSpec: LLMProvider.GROQ: "GROQ_API_KEY", LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", LLMProvider.TOGETHER: "TOGETHER_API_KEY", + LLMProvider.MISTRAL: "MISTRAL_API_KEY", } diff --git a/backend/airweave/adapters/llm/tests/test_mistral.py b/backend/airweave/adapters/llm/tests/test_mistral.py new file mode 100644 index 000000000..6dddc5bb1 --- /dev/null +++ b/backend/airweave/adapters/llm/tests/test_mistral.py @@ -0,0 +1,361 @@ +"""Tests for MistralLLM — mock the SDK client, not the network.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mistralai.models.textchunk import TextChunk +from mistralai.models.thinkchunk import ThinkChunk +from pydantic import BaseModel + +from airweave.adapters.llm.exceptions import LLMProviderExhaustedError +from airweave.adapters.llm.mistral import MistralLLM +from airweave.adapters.llm.registry import LLMModelSpec, ThinkingConfig +from airweave.adapters.tokenizer.registry import TokenizerEncoding, TokenizerType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_spec( + thinking_param: str = "_noop", + thinking_value: str | bool = True, +) -> LLMModelSpec: + return LLMModelSpec( + api_model_name="mistral-large-latest", + context_window=256_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig( + param_name=thinking_param, + param_value=thinking_value, + ), + ) + + +class _DummyOutput(BaseModel): + key: str + + +def _mock_response( + content: str | list | None = '{"key": "value"}', + tool_calls: list | None = None, + prompt_tokens: int = 100, + completion_tokens: int = 50, + total_tokens: int = 150, +) -> MagicMock: + """Build a mock mimicking the Mistral SDK ChatCompletionResponse.""" + mock_choice = MagicMock() + mock_choice.message.content = content + mock_choice.message.tool_calls = tool_calls + + mock_resp = MagicMock() + mock_resp.choices = [mock_choice] + mock_resp.usage = MagicMock( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + return mock_resp + + +@pytest.fixture +def mistral_llm(): + """Instantiate MistralLLM with a patched settings object.""" + with patch("airweave.adapters.llm.mistral.settings") as mock_settings: + mock_settings.MISTRAL_API_KEY = "test-key" + llm = MistralLLM(model_spec=_make_spec(), max_retries=0) + yield llm + + +@pytest.fixture +def mistral_llm_reasoning(): + """MistralLLM configured with reasoning_effort thinking config.""" + with patch("airweave.adapters.llm.mistral.settings") as mock_settings: + mock_settings.MISTRAL_API_KEY = "test-key" + llm = MistralLLM( + model_spec=_make_spec( + thinking_param="reasoning_effort", + thinking_value="high", + ), + max_retries=0, + ) + yield llm + + +# ═══════════════════════════════════════════════════════════════════════════ +# structured_output tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_structured_output_returns_parsed(mistral_llm: MistralLLM) -> None: + """_call_api parses JSON content into the Pydantic model.""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content='{"key": "hello"}') + ) + + result = await mistral_llm.structured_output( + prompt="test prompt", + schema=_DummyOutput, + system_prompt="sys", + ) + + assert isinstance(result, _DummyOutput) + assert result.key == "hello" + + +@pytest.mark.asyncio +async def test_empty_response_raises_transient(mistral_llm: MistralLLM) -> None: + """Empty content from the API raises LLMProviderExhaustedError (wrapping transient).""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=None) + ) + + with pytest.raises(LLMProviderExhaustedError, match="empty response"): + await mistral_llm.structured_output( + prompt="test prompt", + schema=_DummyOutput, + system_prompt="sys", + ) + + +@pytest.mark.asyncio +async def test_structured_output_uses_json_schema(mistral_llm: MistralLLM) -> None: + """_call_api sends json_schema response_format.""" + mock_create = AsyncMock( + return_value=_mock_response(content='{"key": "v"}') + ) + mistral_llm._client.chat.complete_async = mock_create + + await mistral_llm.structured_output( + prompt="test", + schema=_DummyOutput, + system_prompt="sys", + ) + + call_kwargs = mock_create.call_args.kwargs + rf = call_kwargs["response_format"] + assert rf.type == "json_schema" + assert rf.json_schema.strict is True + + +# ═══════════════════════════════════════════════════════════════════════════ +# chat tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_chat_returns_tool_calls(mistral_llm: MistralLLM) -> None: + """chat() extracts tool_calls from the response.""" + mock_tc = MagicMock() + mock_tc.id = "tc-1" + mock_tc.function.name = "search" + mock_tc.function.arguments = '{"query": "hello"}' + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content="text", tool_calls=[mock_tc]) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "search" + assert result.tool_calls[0].arguments == {"query": "hello"} + assert result.prompt_tokens == 100 + + +@pytest.mark.asyncio +async def test_chat_dict_arguments(mistral_llm: MistralLLM) -> None: + """chat() handles arguments already returned as dict (not stringified).""" + mock_tc = MagicMock() + mock_tc.id = "tc-2" + mock_tc.function.name = "lookup" + mock_tc.function.arguments = {"id": 42} + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=None, tool_calls=[mock_tc]) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert result.tool_calls[0].arguments == {"id": 42} + + +@pytest.mark.asyncio +async def test_chat_uses_tool_choice_any(mistral_llm: MistralLLM) -> None: + """chat() passes tool_choice='any' to force tool usage.""" + mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm._client.chat.complete_async = mock_create + + await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["tool_choice"] == "any" + + +# ═══════════════════════════════════════════════════════════════════════════ +# thinking/reasoning tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_chat_extracts_thinking_from_content_chunks(mistral_llm: MistralLLM) -> None: + """chat() extracts thinking from ThinkChunk content blocks.""" + # Simulate Magistral-style response with thinking + text chunks + think_chunk = ThinkChunk( + thinking=[TextChunk(text="Let me reason about this...")], + type="thinking", + ) + text_chunk = TextChunk(text="The answer is 42", type="text") + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=[think_chunk, text_chunk], tool_calls=None) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=True, + ) + + assert result.thinking == "Let me reason about this..." + assert result.text == "The answer is 42" + + +@pytest.mark.asyncio +async def test_chat_no_thinking_returns_none(mistral_llm: MistralLLM) -> None: + """chat() returns thinking=None when no ThinkChunk is present.""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content="plain text", tool_calls=None) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert result.thinking is None + assert result.text == "plain text" + + +@pytest.mark.asyncio +async def test_chat_reasoning_effort_passed(mistral_llm_reasoning: MistralLLM) -> None: + """chat(thinking=True) passes reasoning_effort='high' for Small 4 models.""" + mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm_reasoning._client.chat.complete_async = mock_create + + await mistral_llm_reasoning.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=True, + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["reasoning_effort"] == "high" + + +@pytest.mark.asyncio +async def test_chat_reasoning_effort_none_when_not_thinking( + mistral_llm_reasoning: MistralLLM, +) -> None: + """chat(thinking=False) passes reasoning_effort='none' for Small 4 models.""" + mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm_reasoning._client.chat.complete_async = mock_create + + await mistral_llm_reasoning.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=False, + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["reasoning_effort"] == "none" diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 9fe45e4cd..2d6602f0a 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -24,6 +24,7 @@ from airweave.adapters.llm.cerebras import CerebrasLLM from airweave.adapters.llm.fallback import FallbackChainLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.registry import ( PROVIDER_API_KEY_SETTINGS, LLMProvider, @@ -1195,6 +1196,7 @@ def _build_llm_chain( LLMProvider.ANTHROPIC: AnthropicLLM, LLMProvider.CEREBRAS: CerebrasLLM, LLMProvider.GROQ: GroqLLM, + LLMProvider.MISTRAL: MistralLLM, LLMProvider.TOGETHER: TogetherLLM, } From 6f8c05e3706a1a5190b21e16ca3f5c64ebb98d23 Mon Sep 17 00:00:00 2001 From: Hidde Beydals Date: Wed, 22 Apr 2026 13:08:10 +0200 Subject: [PATCH 08/25] fix: cancel path for pending sync jobs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The state machine forbids PENDING → CANCELLING, but the cancel flow assumed every job passed through CANCELLING on the way to CANCELLED. A PENDING job hitting the API `cancel_job` endpoint would attempt an invalid transition. Split the cancel path by current status: PENDING jobs now go straight to CANCELLED, while RUNNING jobs keep the two-phase CANCELLING → CANCELLED flow. The orchestrator and workflow both attempt the same transitions defensively — the orchestrator acts eagerly inside the activity, the workflow retries via a shielded activity as a safety net. Duplicate arrivals are harmless because the state machine treats same-state writes as no-ops. Shielded (cancel-path) transition failures are now logged at debug rather than warning, since the orchestrator already handled them. The CANCELLED transition failure in the orchestrator keeps warning level — it means the job reached an unexpected terminal state (COMPLETED/FAILED) during cancellation. --- .../domains/sync_pipeline/orchestrator.py | 48 +++++++++++++++---- .../tests/test_orchestrator_coverage.py | 34 +++++++++++-- backend/airweave/domains/syncs/service.py | 25 ++++++++-- .../domains/syncs/tests/test_service.py | 28 ++++++++++- .../tests/test_transition_sync_job.py | 24 +++++++++- .../activities/transition_sync_job.py | 7 +-- .../workflows/run_source_connection.py | 18 ++++--- 7 files changed, 155 insertions(+), 29 deletions(-) diff --git a/backend/airweave/domains/sync_pipeline/orchestrator.py b/backend/airweave/domains/sync_pipeline/orchestrator.py index 4a2f34de5..4796e4cf6 100644 --- a/backend/airweave/domains/sync_pipeline/orchestrator.py +++ b/backend/airweave/domains/sync_pipeline/orchestrator.py @@ -23,7 +23,7 @@ from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.domains.syncs.cursors.service import SyncCursorService from airweave.domains.syncs.jobs.protocols import SyncJobStateMachineProtocol -from airweave.domains.syncs.jobs.types import LifecycleData +from airweave.domains.syncs.jobs.types import InvalidTransitionError, LifecycleData from airweave.domains.syncs.protocols import SyncStateMachineProtocol from airweave.domains.temporal.metrics import worker_metrics from airweave.domains.usage.exceptions import ( @@ -714,21 +714,49 @@ async def _handle_cancellation(self) -> None: """Centralized cancellation handler - explicit and immediate.""" self.sync_context.logger.info("Handling cancellation...") - # 1. Cancel all pending tasks IMMEDIATELY + # Cancel all pending tasks immediately if self.worker_pool: await self.worker_pool.cancel_all() - # 2. Cancel stream to stop producer + # Cancel stream to stop producer await self.stream.cancel() - await self._state_machine.transition( - sync_job_id=self.sync_context.sync_job.id, - target=SyncJobStatus.CANCELLED, - ctx=self.sync_context, - lifecycle_data=self._lifecycle_data, - ) + # Transition through CANCELLING → CANCELLED. + # RUNNING → CANCELLING is required by the state machine. + # If still PENDING (cancellation before _start_sync completed), + # PENDING → CANCELLING is invalid, so fall through to direct CANCELLED. + # + # The workflow will also attempt these transitions via + # TransitionSyncJobActivity once the CancelledError propagates. + # That redundancy is intentional — it guards against the activity + # being killed before the error reaches the workflow. + try: + await self._state_machine.transition( + sync_job_id=self.sync_context.sync_job.id, + target=SyncJobStatus.CANCELLING, + ctx=self.sync_context, + lifecycle_data=self._lifecycle_data, + ) + except InvalidTransitionError as exc: + self.sync_context.logger.debug( + "Skipped CANCELLING transition", + current_state=exc.current.value, + ) + + try: + await self._state_machine.transition( + sync_job_id=self.sync_context.sync_job.id, + target=SyncJobStatus.CANCELLED, + ctx=self.sync_context, + lifecycle_data=self._lifecycle_data, + ) + except InvalidTransitionError as exc: + self.sync_context.logger.warning( + "Skipped CANCELLED transition — job in unexpected terminal state", + current_state=exc.current.value, + ) - # 4. Track sync cancelled + # Track sync cancelled if not self.sync_context.sync_job.started_at: # This can happen if cancellation occurs during _start_sync before # the job status is updated with started_at diff --git a/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py b/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py index 2044b5b2d..2bf4a0a24 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py @@ -1,6 +1,5 @@ """Coverage tests for SyncOrchestrator — missing state_machine.transition lines.""" -import asyncio from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch from uuid import UUID @@ -12,6 +11,7 @@ from airweave.domains.sync_pipeline.contexts.sync import SyncContext from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats +from airweave.domains.syncs.jobs.types import InvalidTransitionError MODULE = "airweave.domains.sync_pipeline.orchestrator" @@ -184,11 +184,37 @@ async def test_handle_sync_failure_calls_state_machine_transition(): @pytest.mark.unit -async def test_handle_cancellation_calls_state_machine_transition(): - """_handle_cancellation calls state_machine.transition with CANCELLED.""" +async def test_handle_cancellation_transitions_through_cancelling(): + """_handle_cancellation transitions RUNNING → CANCELLING → CANCELLED.""" orch, sm = _make_orchestrator() with patch(f"{MODULE}.business_events"): await orch._handle_cancellation() - assert any(c["target"] == SyncJobStatus.CANCELLED for c in sm.calls) + targets = [c["target"] for c in sm.calls] + assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] + + +@pytest.mark.unit +async def test_handle_cancellation_pending_falls_through_to_cancelled(): + """When CANCELLING raises InvalidTransitionError (PENDING), go directly to CANCELLED.""" + + class PendingStateMachine: + def __init__(self): + self.calls: list[dict] = [] + + async def transition(self, **kwargs): + self.calls.append(kwargs) + if kwargs["target"] == SyncJobStatus.CANCELLING: + raise InvalidTransitionError( + SyncJobStatus.PENDING, SyncJobStatus.CANCELLING + ) + return MagicMock(applied=True) + + orch, sm = _make_orchestrator(state_machine=PendingStateMachine()) + + with patch(f"{MODULE}.business_events"): + await orch._handle_cancellation() + + targets = [c["target"] for c in sm.calls] + assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index 024570f2d..a1589c07a 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -284,10 +284,12 @@ async def cancel_job( job_id: UUID, ctx: ApiContext, ) -> schemas.SyncJob: - """Cancel a running sync job. + """Cancel a pending or running sync job. - Transitions to CANCELLING, sends cancel to Temporal, and handles - edge cases (workflow not found, Temporal failure with one retry). + PENDING jobs are transitioned directly to CANCELLED (the CANCELLING + intermediate state is only valid from RUNNING). RUNNING jobs go + through CANCELLING and rely on the Temporal workflow for the final + CANCELLED transition. """ sync_job = await self._sync_job_repo.get(db, job_id, ctx) if not sync_job: @@ -299,6 +301,23 @@ async def cancel_job( detail=f"Cannot cancel job in {sync_job.status} state", ) + if sync_job.status == SyncJobStatus.PENDING: + # PENDING → CANCELLED directly (PENDING → CANCELLING is invalid) + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLED, ctx=ctx + ) + # Best-effort workflow cleanup — workflow may not exist yet for + # PENDING jobs, but if it does we need to stop it. + cancel_result = await self._cancel_temporal_workflow_with_retry(job_id, ctx) + if not cancel_result["success"] and cancel_result["workflow_found"]: + ctx.logger.warning( + f"Temporal cancel failed for PENDING job {job_id}, " + "workflow may continue running against cancelled job", + ) + await db.refresh(sync_job) + return schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + # RUNNING → CANCELLING; workflow handles CANCELLING → CANCELLED await self._job_state_machine.transition( sync_job_id=job_id, target=SyncJobStatus.CANCELLING, ctx=ctx ) diff --git a/backend/airweave/domains/syncs/tests/test_service.py b/backend/airweave/domains/syncs/tests/test_service.py index c55f680f9..9632f49cf 100644 --- a/backend/airweave/domains/syncs/tests/test_service.py +++ b/backend/airweave/domains/syncs/tests/test_service.py @@ -550,12 +550,36 @@ async def test_cancel_job_success(): @pytest.mark.asyncio -async def test_cancel_job_workflow_not_found_marks_cancelled(): +async def test_cancel_pending_job_transitions_directly_to_cancelled(): job_id = uuid4() job_repo = AsyncMock() job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.PENDING) job_repo.get.return_value = job + temporal = AsyncMock() + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + job_sm.transition.assert_awaited_once() + assert job_sm.transition.call_args.kwargs["target"] == SyncJobStatus.CANCELLED + temporal.cancel_sync_job_workflow.assert_awaited_once() + db.refresh.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cancel_job_workflow_not_found_marks_cancelled(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.RUNNING) + job_repo.get.return_value = job + temporal = AsyncMock() temporal.cancel_sync_job_workflow.return_value = { "success": True, @@ -573,6 +597,8 @@ async def test_cancel_job_workflow_not_found_marks_cancelled(): await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) assert job_sm.transition.await_count == 2 + first_call = job_sm.transition.call_args_list[0].kwargs + assert first_call["target"] == SyncJobStatus.CANCELLING second_call = job_sm.transition.call_args_list[1].kwargs assert second_call["target"] == SyncJobStatus.CANCELLED diff --git a/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py b/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py index 706d43ce1..7d0ee6a9f 100644 --- a/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py +++ b/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py @@ -1,6 +1,6 @@ """Tests for TransitionSyncJobActivity.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock from uuid import UUID import pytest @@ -8,7 +8,6 @@ from airweave.core.shared_models import SyncJobStatus from airweave.domains.temporal.activities.transition_sync_job import ( TransitionSyncJobActivity, - _STATUS_MAP, ) from .conftest import ORG_ID, SYNC_ID, SYNC_JOB_ID, make_ctx_dict @@ -100,6 +99,27 @@ async def test_failed_transition_with_error(activity, state_machine): assert call["error"] == "Something went wrong" +@pytest.mark.unit +async def test_cancelling_transition(activity, state_machine): + lifecycle = { + "organization_id": ORG_ID, + "sync_id": SYNC_ID, + "sync_job_id": SYNC_JOB_ID, + "collection_id": "00000000-0000-0000-0000-000000000030", + "source_connection_id": "00000000-0000-0000-0000-000000000050", + } + + await activity.run( + transition="cancelling", + sync_job_id=SYNC_JOB_ID, + ctx_dict=make_ctx_dict(), + lifecycle_data=lifecycle, + ) + + call = state_machine.calls[0] + assert call["target"] == SyncJobStatus.CANCELLING + + @pytest.mark.unit async def test_cancelled_transition(activity, state_machine): lifecycle = { diff --git a/backend/airweave/domains/temporal/activities/transition_sync_job.py b/backend/airweave/domains/temporal/activities/transition_sync_job.py index cf8241259..1dad9f77e 100644 --- a/backend/airweave/domains/temporal/activities/transition_sync_job.py +++ b/backend/airweave/domains/temporal/activities/transition_sync_job.py @@ -1,6 +1,6 @@ """Transition sync job activity — thin Temporal wrapper over SyncJobStateMachine. -Called by the workflow for COMPLETED, FAILED, and CANCELLED transitions. +Called by the workflow for CANCELLING, COMPLETED, FAILED, and CANCELLED transitions. Deserializes Temporal payloads and delegates to the state machine. """ @@ -19,6 +19,7 @@ from airweave.domains.temporal.activities.context import build_activity_context _STATUS_MAP: dict[str, SyncJobStatus] = { + "cancelling": SyncJobStatus.CANCELLING, "completed": SyncJobStatus.COMPLETED, "failed": SyncJobStatus.FAILED, "cancelled": SyncJobStatus.CANCELLED, @@ -46,10 +47,10 @@ async def run( stats_dict: Optional[Dict[str, Any]] = None, timestamp_iso: Optional[str] = None, ) -> None: - """Execute a terminal state transition via the state machine. + """Execute a state transition via the state machine. Args: - transition: One of "completed", "failed", "cancelled". + transition: One of "cancelling", "completed", "failed", "cancelled". sync_job_id: The sync job UUID as a string. ctx_dict: Serialized context dict (contains organization). lifecycle_data: Fields for building LifecycleData. diff --git a/backend/airweave/domains/temporal/workflows/run_source_connection.py b/backend/airweave/domains/temporal/workflows/run_source_connection.py index 1d33ae422..d79a50412 100644 --- a/backend/airweave/domains/temporal/workflows/run_source_connection.py +++ b/backend/airweave/domains/temporal/workflows/run_source_connection.py @@ -1,6 +1,6 @@ """Run source connection workflow — the sync state machine. -Owns terminal state transitions (COMPLETED, FAILED, CANCELLED) via +Owns state transitions (CANCELLING, COMPLETED, FAILED, CANCELLED) via TransitionSyncJobActivity. RUNNING is published by the orchestrator because only it knows when sync work actually begins. """ @@ -95,7 +95,9 @@ async def run( ) except BaseException as e: if is_cancelled_exception(e): - await self._transition("cancelled", sync_job_dict, ctx_dict, lifecycle, shield=True) + cancel_args = (sync_job_dict, ctx_dict, lifecycle) + await self._transition("cancelling", *cancel_args, shield=True) + await self._transition("cancelled", *cancel_args, shield=True) raise if self._is_orphaned_sync_error(e): reason = self._extract_orphaned_reason(e) @@ -203,7 +205,12 @@ async def _transition( error: Optional[str] = None, shield: bool = False, ) -> None: - """Call TransitionSyncJobActivity for a terminal state change.""" + """Call TransitionSyncJobActivity for a state change. + + Shielded transitions (cancel path) are best-effort retries — the + orchestrator already performed these transitions, so failures here + are expected and logged at debug level. + """ timestamp = workflow.now().replace(tzinfo=None).isoformat() coro = workflow.execute_activity( transition_sync_job_activity, @@ -227,9 +234,8 @@ async def _transition( try: await (asyncio.shield(coro) if shield else coro) except Exception: - workflow.logger.warning( - f"Failed to transition sync job {sync_job_dict.get('id')} to {transition}" - ) + log = workflow.logger.debug if shield else workflow.logger.warning + log(f"Failed to transition sync job {sync_job_dict.get('id')} to {transition}") # ------------------------------------------------------------------ # Self-destruct orphaned sync From 7c0e66c1d67561d709971596d35b2b3edf58a725 Mon Sep 17 00:00:00 2001 From: Hidde Beydals Date: Wed, 22 Apr 2026 15:25:22 +0200 Subject: [PATCH 09/25] fix(ci): harden scan-and-attest resilience MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `scan-and-attest` job used implicit `if: success()` on every step after Trivy, so a critical vulnerability finding would skip SBOM generation, Grype scanning, attestations, and artifact uploads entirely. Chain each step off its real dependencies with `!cancelled()` and outcome checks instead, so the pipeline produces as many supply-chain artifacts as possible regardless of individual scan failures. Also fix the Trivy SARIF upload condition — the previous `steps.trivy.outputs.sarif` reference checked an output the action never exposes, silently skipping every upload. Use `steps.trivy.outcome != 'skipped'` to match the file-based contract documented upstream. Other changes: add Grype SARIF upload to Code Scanning, add `only-fixed` to Grype so unfixable criticals do not gate the build, and let `package-vespa` run when scans fail since it is independent of container vulnerability results. --- .github/workflows/build.yml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b3c69cf17..ec7ff3cd6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -167,6 +167,7 @@ jobs: images: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} - name: Trivy image scan + id: trivy uses: aquasecurity/trivy-action@97e0b3872f55f89b95b2f65b3dbab56962816478 # v0.34.2 with: image-ref: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}:${{ steps.meta.outputs.version }} @@ -179,13 +180,15 @@ jobs: limit-severities-for-sarif: true - name: Upload Trivy SARIF - if: always() + if: ${{ !cancelled() && steps.trivy.outcome != 'skipped' }} uses: github/codeql-action/upload-sarif@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 with: sarif_file: trivy-${{ matrix.component }}.sarif category: trivy-${{ matrix.component }} - name: Generate SBOM + id: sbom + if: ${{ !cancelled() }} uses: anchore/sbom-action@17ae1740179002c89186b61233e0f892c3118b11 # v0.23.0 with: image: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}:${{ steps.meta.outputs.version }} @@ -193,13 +196,24 @@ jobs: output-file: sbom-${{ matrix.component }}.json - name: Vulnerability scan on SBOM + id: grype + if: ${{ !cancelled() && steps.sbom.outcome == 'success' }} uses: anchore/scan-action@7037fa011853d5a11690026fb85feee79f4c946c # v7.3.2 with: sbom: sbom-${{ matrix.component }}.json fail-build: "true" severity-cutoff: critical + only-fixed: "true" + + - name: Upload Grype SARIF + if: ${{ always() && steps.grype.outputs.sarif }} + uses: github/codeql-action/upload-sarif@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 + with: + sarif_file: ${{ steps.grype.outputs.sarif }} + category: grype-${{ matrix.component }} - name: Get image digest + if: ${{ !cancelled() }} id: digest run: | IMAGE="${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}" @@ -209,6 +223,7 @@ jobs: echo "Image digest: ${DIGEST}" - name: Attest SBOM + if: ${{ !cancelled() && steps.digest.outcome == 'success' && steps.sbom.outcome == 'success' }} uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0 with: subject-name: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} @@ -217,6 +232,7 @@ jobs: push-to-registry: true - name: Attest build provenance + if: ${{ !cancelled() && steps.digest.outcome == 'success' }} uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0 with: subject-name: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} @@ -224,6 +240,7 @@ jobs: push-to-registry: true - name: Upload SBOM artifact + if: ${{ !cancelled() && steps.sbom.outcome == 'success' }} uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: sbom-${{ matrix.component }} @@ -233,6 +250,7 @@ jobs: # Package Vespa application and attach to the GitHub Release package-vespa: needs: scan-and-attest + if: ${{ !cancelled() && needs.scan-and-attest.result != 'skipped' }} runs-on: ubuntu-latest permissions: contents: write From 9eccace9755ba40e128db3d8ee37e82be79f6b89 Mon Sep 17 00:00:00 2001 From: Daan Manneke Date: Wed, 22 Apr 2026 09:31:31 +0200 Subject: [PATCH 10/25] fix(sharepoint_online): correct SP site group expansion - Fix regression where _expand_sp_site_groups and _fetch_sp_group_viewers silently early-return when self._site_url is empty, breaking the multi-site sync mode (getAllSites). Track SP groups per site URL. - Drop role principals (PrincipalType=16, rolemanager|spo-grid-all-users/...) instead of emitting them as fake users. - Emit nested Entra groups (PrincipalType=4, federateddirectoryclaimprovider) as group-to-group memberships (entra:{guid}) rather than flattening to the group's email. - Parse LoginName via a strict regex for the i:0#.f|membership| pattern only, removing the naive split("|")[-1] that accepted role claims. Cursor format for tracked_sp_groups migrated from List[str] to Dict[site_url, List[str]]; legacy list values are discarded defensively and re-collected on next full sync. Incremental and targeted sync paths still pass an empty site_url (noted as follow-up work). --- .../sources/sharepoint_online/source.py | 309 ++++++++++++---- .../test_sharepoint_online_group_expansion.py | 340 ++++++++++++++++++ 2 files changed, 581 insertions(+), 68 deletions(-) create mode 100644 backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index c5eb96336..c5d7eb4aa 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -27,8 +27,9 @@ from __future__ import annotations import asyncio +import re from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Set, Tuple from urllib.parse import urlparse import httpx @@ -105,8 +106,10 @@ class SharePointOnlineBase(BaseSource): _site_url: str _include_personal_sites: bool _include_pages: bool - _item_level_entra_groups: set - _item_level_sp_groups: set + _item_level_entra_groups: Set[str] + # Site-scoped SP group tracking: {site_url: {sp_group_name, ...}} + # Keyed by normalized site URL so multi-site syncs can expand SP groups per site. + _item_level_sp_groups: Dict[str, Set[str]] def _init_common(self, config: SharePointOnlineConfig) -> None: """Initialize fields shared by both OAuth and client-credentials sources.""" @@ -114,7 +117,7 @@ def _init_common(self, config: SharePointOnlineConfig) -> None: self._include_personal_sites = config.include_personal_sites self._include_pages = config.include_pages self._item_level_entra_groups = set() - self._item_level_sp_groups = set() + self._item_level_sp_groups = {} # -- Auth hooks (subclasses override) -- @@ -127,7 +130,19 @@ async def _handle_401(self) -> str: raise NotImplementedError def _make_sp_token_provider(self) -> Optional[Callable]: - """Create an async callable returning a SharePoint REST API token, or None.""" + """Create an async callable returning a SharePoint REST API token, or None. + + Legacy single-site provider that derives hostname from ``self._site_url``. + New code should use ``_make_sp_token_provider_for_site`` instead. + """ + raise NotImplementedError + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create an SP REST token provider scoped to a specific site URL. + + Subclasses must override. Returns None if a token cannot be obtained + for the given site (e.g., malformed URL). + """ raise NotImplementedError async def _get_download_auth(self, url: str) -> Any: @@ -191,16 +206,101 @@ def _derive_sp_hostname(self) -> Optional[str]: parsed = urlparse(self._site_url) return parsed.netloc or None - def _track_entity_groups(self, entity: BaseEntity) -> None: - """Track Entra ID and SP site groups found in entity permissions.""" + @staticmethod + def _normalize_site_url(site_url: str) -> str: + """Normalize a site URL for use as a dict key (strip trailing slash).""" + return (site_url or "").rstrip("/") + + def _track_entity_groups(self, entity: BaseEntity, site_url: str = "") -> None: + """Track Entra ID and SP site groups found in entity permissions. + + Args: + entity: The entity whose access viewers to inspect. + site_url: The site URL this entity belongs to. SP groups are keyed + by site URL so multi-site syncs can expand SP groups per-site. + May be empty for paths that lack site context (incremental / + targeted single-file); those SP groups won't expand. + """ if not hasattr(entity, "access") or entity.access is None: return + norm_site = self._normalize_site_url(site_url) for viewer in entity.access.viewers or []: if viewer.startswith("group:entra:"): group_id = viewer[len("group:") :] self._item_level_entra_groups.add(group_id) elif viewer.startswith("group:sp:"): - self._item_level_sp_groups.add(viewer[len("group:") :]) + sp_name = viewer[len("group:") :] + self._item_level_sp_groups.setdefault(norm_site, set()).add(sp_name) + + # -- SP site group membership parsing -- + + # Match regular user logins: "i:0#.f|membership|" + _MEMBERSHIP_LOGIN_RE = re.compile(r"^i:0#\.f\|membership\|(?P[^|]+@[^|]+)$") + # Match Entra federated group logins: "c:0o.c|federateddirectoryclaimprovider|[_o]" + _ENTRA_GROUP_LOGIN_RE = re.compile( + r"^c:0o\.c\|federateddirectoryclaimprovider\|" + r"(?P[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" + r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(_o)?$" + ) + + @classmethod + def _email_from_membership_login(cls, login: str) -> Optional[str]: + """Extract email from SP user LoginName if it follows the membership pattern. + + Only matches "i:0#.f|membership|". Returns None for role principals + (e.g., "c:0-.f|rolemanager|spo-grid-all-users/...") and other shapes so + we don't pollute the membership table with fake email-like strings. + """ + if not login: + return None + m = cls._MEMBERSHIP_LOGIN_RE.match(login) + if m: + return m.group("email").strip().lower() or None + return None + + @classmethod + def _parse_sp_group_member(cls, user: Dict[str, Any]) -> Optional[Tuple[str, str]]: + """Parse one entry from /_api/web/sitegroups({id})/users into (member_id, member_type). + + Returns None for entries that should not become memberships: + - Role principals (PrincipalType=16, e.g. "Everyone except external users") + - Catch-all "All" principals (PrincipalType=15) + - DistList, SPGroup, unknown types (skipped; rare in practice) + - Unparseable entries (no email for users, no GUID for groups) + + PrincipalType reference: + 1 = User + 2 = DistList + 4 = SecurityGroup (Entra group when LoginName uses + federateddirectoryclaimprovider) + 8 = SPGroup + 15 = All + 16 = RoleManager + """ + ptype = user.get("PrincipalType") + login = user.get("LoginName", "") or "" + + if ptype == 1: + email = user.get("Email") or "" + email = email.strip().lower() + if not email: + email = cls._email_from_membership_login(login) or "" + if not email: + return None + # Bare email (no "user:" prefix) matches the broker storage + # convention used by EntraGroupExpander and SP 2019 V2. + return (email, "user") + + if ptype == 4: + m = cls._ENTRA_GROUP_LOGIN_RE.match(login) + if not m: + return None + guid = m.group("guid").lower() + return (f"entra:{guid}", "group") + + # PrincipalType 2 (DistList), 8 (SPGroup), 15 (All), 16 (RoleManager), + # and unknown types are intentionally skipped. + return None # -- Browse Tree -- @@ -440,7 +540,7 @@ def _should_do_full_sync(self, cursor: SyncCursor | None) -> tuple: # -- Entity Generation -- - async def generate_entities( + async def generate_entities( # noqa: C901 self, *, cursor: SyncCursor | None = None, @@ -451,8 +551,19 @@ async def generate_entities( cursor_data = cursor.data if cursor else {} for g in cursor_data.get("tracked_entra_groups", []): self._item_level_entra_groups.add(g) - for g in cursor_data.get("tracked_sp_groups", []): - self._item_level_sp_groups.add(g) + + # tracked_sp_groups format changed from List[str] (flat names) to + # Dict[site_url, List[str]] (site-scoped). Migrate defensively. + tracked_sp = cursor_data.get("tracked_sp_groups") + if isinstance(tracked_sp, dict): + for site_url, names in tracked_sp.items(): + if isinstance(names, list): + self._item_level_sp_groups[site_url] = set(names) + elif isinstance(tracked_sp, list): + self.logger.info( + "Legacy tracked_sp_groups list format detected; discarding — " + "will re-collect on next full sync" + ) if node_selections: self.logger.info(f"Sync strategy: TARGETED ({len(node_selections)} node selections)") @@ -495,10 +606,18 @@ async def _resolve_unresolved_viewers( new_viewers.append(v) entity.access.viewers = new_viewers - async def _fetch_sp_group_viewers(self) -> List[str]: - """Fetch all SP site groups and return their viewer strings.""" - sp_token_provider = self._make_sp_token_provider() - if not sp_token_provider or not self._site_url: + async def _fetch_sp_group_viewers(self, site_url: str) -> List[str]: + """Fetch all SP site groups for a site and return their viewer strings. + + Args: + site_url: Full site URL (e.g. https://tenant.sharepoint.com/sites/X). + Required — without it we can't hit the SP REST endpoint. + """ + norm_site = self._normalize_site_url(site_url) + if not norm_site: + return [] + sp_token_provider = self._make_sp_token_provider_for_site(norm_site) + if not sp_token_provider: return [] try: token = await sp_token_provider() @@ -507,7 +626,7 @@ async def _fetch_sp_group_viewers(self) -> List[str]: "Accept": "application/json;odata=verbose", } resp = await self.http_client.get( - f"{self._site_url}/_api/web/sitegroups", + f"{norm_site}/_api/web/sitegroups", headers=headers, timeout=30.0, ) @@ -515,18 +634,19 @@ async def _fetch_sp_group_viewers(self) -> List[str]: groups = resp.json().get("d", {}).get("results", []) viewers = [] + site_bucket = self._item_level_sp_groups.setdefault(norm_site, set()) for g in groups: title = g.get("Title", "") if title: tag = f"group:sp:{title.lower().replace(' ', '_')}" viewers.append(tag) - self._item_level_sp_groups.add(tag[len("group:") :]) - self.logger.info(f"Fetched {len(viewers)} SP site groups as viewers") + site_bucket.add(tag[len("group:") :]) + self.logger.info(f"Fetched {len(viewers)} SP site groups as viewers for {norm_site}") return viewers except SourceAuthError: raise except Exception as e: - self.logger.warning(f"SP group fetch failed: {e}") + self.logger.warning(f"SP group fetch failed for {norm_site}: {e}") return [] async def _full_sync( # noqa: C901 @@ -541,6 +661,7 @@ async def _full_sync( # noqa: C901 for site_data in sites: site_id = site_data.get("id", "") + site_url = self._normalize_site_url(site_data.get("webUrl", "")) # Collect all drives for this site (single API call) all_drives = [] @@ -560,7 +681,7 @@ async def _full_sync( # noqa: C901 try: site_entity = await build_site_entity(site_data, [], access=site_access) - self._track_entity_groups(site_entity) + self._track_entity_groups(site_entity, site_url) yield site_entity entity_count += 1 @@ -574,7 +695,7 @@ async def _full_sync( # noqa: C901 self.logger.warning(f"Skipping site {site_id}: {e}") continue - sp_group_viewers = await self._fetch_sp_group_viewers() + sp_group_viewers = await self._fetch_sp_group_viewers(site_url) for drive_data in all_drives: drive_id = drive_data.get("id", "") @@ -593,7 +714,7 @@ async def _full_sync( # noqa: C901 drive_entity = await build_drive_entity( drive_data, site_id, site_breadcrumbs, access=drive_access ) - self._track_entity_groups(drive_entity) + self._track_entity_groups(drive_entity, site_url) yield drive_entity entity_count += 1 @@ -631,7 +752,7 @@ async def _full_sync( # noqa: C901 for spv in sp_group_viewers: if spv not in existing: file_entity.access.viewers.append(spv) - self._track_entity_groups(file_entity) + self._track_entity_groups(file_entity, site_url) if files: pending_files.append( @@ -697,7 +818,7 @@ async def _full_sync( # noqa: C901 page_entity = await build_page_entity( page_data, site_id, site_breadcrumbs, access=site_access ) - self._track_entity_groups(page_entity) + self._track_entity_groups(page_entity, site_url) yield page_entity entity_count += 1 except EntityProcessingError as e: @@ -718,7 +839,9 @@ async def _full_sync( # noqa: C901 full_sync_required=False, total_entities_synced=entity_count, tracked_entra_groups=list(self._item_level_entra_groups), - tracked_sp_groups=list(self._item_level_sp_groups), + tracked_sp_groups={ + site: sorted(names) for site, names in self._item_level_sp_groups.items() + }, ) self.logger.info(f"Full sync complete: {entity_count} entities") @@ -850,6 +973,7 @@ async def _targeted_sync( # noqa: C901 try: site_data = await graph_client.get_site(site_id) + targeted_site_url = self._normalize_site_url(site_data.get("webUrl", "")) # Fetch site-level permissions from first drive root targeted_site_access = None @@ -862,7 +986,7 @@ async def _targeted_sync( # noqa: C901 break site_entity = await build_site_entity(site_data, [], access=targeted_site_access) - self._track_entity_groups(site_entity) + self._track_entity_groups(site_entity, targeted_site_url) yield site_entity entity_count += 1 except SourceAuthError: @@ -1069,51 +1193,85 @@ async def _expand_entra_groups( async for membership in group_expander.expand_group(group_id): yield membership - async def _expand_sp_site_groups(self) -> AsyncGenerator[MembershipTuple, None]: - """Expand tracked SP site groups into user memberships.""" - sp_group_names = list(self._item_level_sp_groups) - if not sp_group_names or not self._site_url: - return - sp_token_provider = self._make_sp_token_provider() - if not sp_token_provider: - self.logger.warning("No SP token provider for site group expansion") + async def _expand_sp_site_groups( # noqa: C901 + self, + ) -> AsyncGenerator[MembershipTuple, None]: + """Expand tracked SP site groups into user/group memberships. + + Iterates per-site: for each site URL we've tracked SP group names against, + fetches that site's SP groups via the SharePoint REST API and resolves + their members. + + Member types emitted: + - ``user`` for real users (PrincipalType=1). Role principals like + "Everyone except external users" are skipped. + - ``group`` for Entra security groups nested inside SP groups + (PrincipalType=4 with federateddirectoryclaimprovider). The broker's + recursive group expansion resolves these to individual users at + search time. + """ + if not self._item_level_sp_groups: return - self.logger.info(f"Expanding {len(sp_group_names)} SP site groups") + total_groups = sum(len(v) for v in self._item_level_sp_groups.values()) + self.logger.info( + f"Expanding {total_groups} SP site groups across " + f"{len(self._item_level_sp_groups)} site(s)" + ) + graph_client = self._create_graph_client() - sp_groups = await graph_client.get_site_groups( - self._site_url, - sp_token_provider=sp_token_provider, - ) - sp_name_to_id = { - f"sp:{g['Title'].replace(' ', '_').lower()}": g.get("Id") - for g in sp_groups - if g.get("Title") - } - - for sp_name in sp_group_names: - sp_id = sp_name_to_id.get(sp_name) - if not sp_id: - self.logger.debug(f"SP group '{sp_name}' not found in site") + for site_url, sp_group_names in self._item_level_sp_groups.items(): + if not site_url or not sp_group_names: continue - users = await graph_client.get_site_group_users( - self._site_url, - sp_id, - sp_token_provider=sp_token_provider, - ) - for user in users: - email = user.get("Email", "") - login = user.get("LoginName", "") - if not email and login and "|" in login: - email = login.split("|")[-1] - if email: + sp_token_provider = self._make_sp_token_provider_for_site(site_url) + if not sp_token_provider: + self.logger.warning( + f"No SP token provider for site {site_url}; skipping SP group expansion" + ) + continue + + try: + sp_groups = await graph_client.get_site_groups( + site_url, sp_token_provider=sp_token_provider + ) + except Exception as e: + self.logger.warning(f"Failed to fetch SP groups for {site_url}: {e}") + continue + + sp_name_to_id = { + f"sp:{g['Title'].replace(' ', '_').lower()}": g.get("Id") + for g in sp_groups + if g.get("Title") + } + + for sp_name in sp_group_names: + sp_id = sp_name_to_id.get(sp_name) + if not sp_id: + self.logger.debug(f"SP group '{sp_name}' not found in site {site_url}") + continue + + try: + users = await graph_client.get_site_group_users( + site_url, sp_id, sp_token_provider=sp_token_provider + ) + except Exception as e: + self.logger.warning( + f"Failed to fetch users for SP group {sp_name} in {site_url}: {e}" + ) + continue + + for user in users: + parsed = self._parse_sp_group_member(user) + if parsed is None: + continue + member_id, member_type = parsed yield MembershipTuple( - member_id=email.lower(), - member_type="user", + member_id=member_id, + member_type=member_type, group_id=sp_name, - group_name=user.get("Title", sp_name), + group_name=user.get("Title") or sp_name, ) async def generate_access_control_memberships( @@ -1194,10 +1352,18 @@ async def _handle_401(self) -> str: return await self.auth.get_token() def _make_sp_token_provider(self) -> Optional[Callable]: - """Create SP token provider via OAuth scope exchange.""" - sp_scope = self._derive_sp_resource_scope() - if not sp_scope: + """Create SP token provider via OAuth scope exchange (config-site-bound).""" + return self._make_sp_token_provider_for_site(self._site_url) + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create SP token provider for a specific site URL via OAuth scope exchange.""" + if not site_url: return None + parsed = urlparse(site_url) + hostname = parsed.netloc + if not hostname: + return None + sp_scope = f"https://{hostname}/.default" async def _provider() -> str: token = await self.get_token_for_resource(sp_scope) @@ -1431,8 +1597,15 @@ async def _handle_401(self) -> str: return await self._get_access_token() def _make_sp_token_provider(self) -> Optional[Callable]: - """Create SP token provider via certificate exchange.""" - hostname = self._derive_sp_hostname() + """Create SP token provider via certificate exchange (config-site-bound).""" + return self._make_sp_token_provider_for_site(self._site_url) + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create SP token provider for a specific site URL via certificate exchange.""" + if not site_url: + return None + parsed = urlparse(site_url) + hostname = parsed.netloc if not hostname: return None diff --git a/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py new file mode 100644 index 000000000..669912b17 --- /dev/null +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py @@ -0,0 +1,340 @@ +"""Unit tests for SharePoint Online SP site group expansion helpers. + +Covers _parse_sp_group_member, _email_from_membership_login, and the cursor +migration path for tracked_sp_groups. +""" + +from unittest.mock import MagicMock + +from airweave.platform.sources.sharepoint_online.source import SharePointOnlineBase + +# --------------------------------------------------------------------------- +# _email_from_membership_login +# --------------------------------------------------------------------------- + + +def test_email_from_membership_login_valid(): + assert ( + SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|foo@bar.com") + == "foo@bar.com" + ) + + +def test_email_from_membership_login_uppercase_normalized(): + assert ( + SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|Foo@BAR.com") + == "foo@bar.com" + ) + + +def test_email_from_membership_login_rejects_role_principal(): + # Role principals would otherwise yield "spo-grid-all-users/..." — must reject. + assert ( + SharePointOnlineBase._email_from_membership_login( + "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45" + ) + is None + ) + + +def test_email_from_membership_login_rejects_federated_group(): + assert ( + SharePointOnlineBase._email_from_membership_login( + "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579" + ) + is None + ) + + +def test_email_from_membership_login_rejects_empty(): + assert SharePointOnlineBase._email_from_membership_login("") is None + + +def test_email_from_membership_login_rejects_malformed(): + assert SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|noat") is None + + +# --------------------------------------------------------------------------- +# _parse_sp_group_member +# --------------------------------------------------------------------------- + + +def test_parse_real_user_with_email(): + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|alice@contoso.com", + "Email": "alice@contoso.com", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_uppercase_email_normalized(): + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|ALICE@CONTOSO.COM", + "Email": "ALICE@CONTOSO.COM", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_email_empty_fallback_to_login(): + # If Email is missing but LoginName has the membership pattern, use that. + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|alice@contoso.com", + "Email": "", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_no_email_no_parseable_login_returns_none(): + # System Account and similar — no Email, no membership LoginName. + user = { + "PrincipalType": 1, + "LoginName": "SHAREPOINT\\system", + "Email": "", + "Title": "System Account", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_role_principal_skipped(): + """Bug B regression test — 'Everyone except external users' must not become a fake user.""" + user = { + "PrincipalType": 16, + "LoginName": "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45", + "Email": "", + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_entra_group_emits_group_membership(): + """Bug C/D regression test — Entra group must be emitted as group-to-group.""" + user = { + "PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579", + "Email": "neena@neenacorp.onmicrosoft.com", # group's email, must NOT be used + "Title": "Neena Members", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_owner_suffix_stripped(): + """Owner-style claim has `_o` suffix — must strip it to get the bare GUID.""" + login = "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579_o" + user = { + "PrincipalType": 4, + "LoginName": login, + "Email": "neena@neenacorp.onmicrosoft.com", + "Title": "Neena Owners", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_uppercase_guid_normalized(): + user = { + "PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|58CB1814-203A-44D0-8578-B53F63860579", + "Title": "Neena Owners", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_malformed_guid_returns_none(): + user = { + "PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|not-a-guid", + "Title": "Bad Group", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_security_group_non_federated_returns_none(): + # PrincipalType=4 but not federated — on-prem AD claim, skip. + user = { + "PrincipalType": 4, + "LoginName": "c:0-.f|adclaimprovider|S-1-5-21-...", + "Title": "On-prem Group", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_distlist_skipped(): + user = {"PrincipalType": 2, "LoginName": "some-dl", "Title": "DL"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_spgroup_skipped(): + user = {"PrincipalType": 8, "LoginName": "some-sp", "Title": "SP"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_all_catchall_skipped(): + user = {"PrincipalType": 15, "LoginName": "everyone", "Title": "All"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_unknown_principal_type_skipped(): + user = {"PrincipalType": 99, "LoginName": "x", "Title": "X"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_missing_principal_type_skipped(): + user = {"LoginName": "x", "Title": "X", "Email": "x@y.z"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +# --------------------------------------------------------------------------- +# _normalize_site_url +# --------------------------------------------------------------------------- + + +def test_normalize_site_url_strips_trailing_slash(): + assert ( + SharePointOnlineBase._normalize_site_url("https://contoso.sharepoint.com/sites/X/") + == "https://contoso.sharepoint.com/sites/X" + ) + + +def test_normalize_site_url_empty(): + assert SharePointOnlineBase._normalize_site_url("") == "" + assert SharePointOnlineBase._normalize_site_url(None) == "" # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# _track_entity_groups with site_url scoping +# --------------------------------------------------------------------------- + + +class _StubEntity: + def __init__(self, viewers): + self.access = MagicMock() + self.access.viewers = viewers + + +def _bare_base() -> SharePointOnlineBase: + """Instantiate the base class just enough to exercise tracking logic. + + We bypass the normal source creation path since we only need the tracking + state and its methods. + """ + instance = SharePointOnlineBase.__new__(SharePointOnlineBase) + instance._site_url = "" + instance._include_personal_sites = False + instance._include_pages = False + instance._item_level_entra_groups = set() + instance._item_level_sp_groups = {} + return instance + + +def test_track_entity_groups_scopes_sp_by_site(): + base = _bare_base() + e = _StubEntity( + [ + "group:sp:neena_members", + "group:sp:neena_owners", + "group:entra:58cb1814-203a-44d0-8578-b53f63860579", + "user:alice@contoso.com", + ] + ) + base._track_entity_groups(e, "https://neenacorp.sharepoint.com/sites/Neena77") + + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/Neena77": { + "sp:neena_members", + "sp:neena_owners", + } + } + assert base._item_level_entra_groups == {"entra:58cb1814-203a-44d0-8578-b53f63860579"} + + +def test_track_entity_groups_multiple_sites_keep_separate(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:neena_members"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + base._track_entity_groups( + _StubEntity(["group:sp:access_control_tests_owners"]), + "https://neenacorp.sharepoint.com/sites/B", + ) + + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/A": {"sp:neena_members"}, + "https://neenacorp.sharepoint.com/sites/B": {"sp:access_control_tests_owners"}, + } + + +def test_track_entity_groups_same_name_different_sites_do_not_collide(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:members"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + base._track_entity_groups( + _StubEntity(["group:sp:members"]), + "https://neenacorp.sharepoint.com/sites/B", + ) + + # Same group name but two different sites — must be tracked independently. + assert set(base._item_level_sp_groups.keys()) == { + "https://neenacorp.sharepoint.com/sites/A", + "https://neenacorp.sharepoint.com/sites/B", + } + + +def test_track_entity_groups_normalizes_trailing_slash(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:x"]), + "https://neenacorp.sharepoint.com/sites/A/", + ) + base._track_entity_groups( + _StubEntity(["group:sp:y"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + # Both should land under the same normalized key. + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/A": {"sp:x", "sp:y"} + } + + +def test_track_entity_groups_no_access_noop(): + base = _bare_base() + entity = MagicMock() + entity.access = None + base._track_entity_groups(entity, "https://neenacorp.sharepoint.com/sites/A") + assert base._item_level_sp_groups == {} + + +def test_track_entity_groups_empty_site_url_still_stores_under_empty_key(): + """Groups are still stored under the empty-string key when no site_url. + + Expansion skips empty-key buckets, so this is effectively a no-op for + broker purposes but keeps the data structure consistent. + """ + base = _bare_base() + base._track_entity_groups(_StubEntity(["group:sp:orphan"]), "") + assert base._item_level_sp_groups == {"": {"sp:orphan"}} From c9fccf86bebe35394da59074660829cfe12a8c02 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 09:52:39 +0200 Subject: [PATCH 11/25] feat: lazy + configurable LLM fallback chain Let deployments boot without any chain-provider API key: when the chain resolves to zero providers, inject an UnavailableLLM null-object instead of raising. Instant search keeps working; classic and agentic search surface LLMUnavailableError on first use, mapped to HTTP 503 via the existing handler. Also expose LLM_FALLBACK_CHAIN as an env var (format: "provider:model,provider:model") so private-cloud deployers can override the chain without forking. Unset falls back to the current hardcoded [together:zai-glm-5, anthropic:claude-sonnet-4.6] default. --- .../adapters/llm/tests/test_unavailable.py | 56 ++++++++++++++ backend/airweave/adapters/llm/unavailable.py | 58 ++++++++++++++ backend/airweave/api/middleware.py | 2 + backend/airweave/core/config/settings.py | 7 ++ backend/airweave/core/container/factory.py | 23 +++--- .../container/tests/test_llm_chain_wiring.py | 53 +++++++++++++ backend/airweave/core/exceptions.py | 22 ++++++ backend/airweave/domains/search/config.py | 64 ++++++++++++++-- .../domains/search/tests/test_config.py | 75 +++++++++++++++++++ 9 files changed, 343 insertions(+), 17 deletions(-) create mode 100644 backend/airweave/adapters/llm/tests/test_unavailable.py create mode 100644 backend/airweave/adapters/llm/unavailable.py create mode 100644 backend/airweave/core/container/tests/test_llm_chain_wiring.py create mode 100644 backend/airweave/domains/search/tests/test_config.py diff --git a/backend/airweave/adapters/llm/tests/test_unavailable.py b/backend/airweave/adapters/llm/tests/test_unavailable.py new file mode 100644 index 000000000..0b9c76deb --- /dev/null +++ b/backend/airweave/adapters/llm/tests/test_unavailable.py @@ -0,0 +1,56 @@ +"""Tests for UnavailableLLM null-object adapter.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from airweave.adapters.llm.unavailable import UnavailableLLM +from airweave.core.exceptions import LLMUnavailableError + + +class _Dummy(BaseModel): + key: str + + +@pytest.mark.asyncio +async def test_structured_output_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + await llm.structured_output(prompt="x", schema=_Dummy, system_prompt="y") + + +@pytest.mark.asyncio +async def test_chat_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + await llm.chat(messages=[], tools=[], system_prompt="y") + + +def test_model_spec_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + _ = llm.model_spec + + +@pytest.mark.asyncio +async def test_close_is_a_safe_noop() -> None: + llm = UnavailableLLM() + assert await llm.close() is None + + +def test_error_message_mentions_accepted_api_key_env_vars() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError) as excinfo: + _ = llm.model_spec + + message = str(excinfo.value) + for env_var in ( + "TOGETHER_API_KEY", + "ANTHROPIC_API_KEY", + "MISTRAL_API_KEY", + "GROQ_API_KEY", + "CEREBRAS_API_KEY", + "LLM_FALLBACK_CHAIN", + ): + assert env_var in message, f"{env_var} missing from error message" diff --git a/backend/airweave/adapters/llm/unavailable.py b/backend/airweave/adapters/llm/unavailable.py new file mode 100644 index 000000000..1065706de --- /dev/null +++ b/backend/airweave/adapters/llm/unavailable.py @@ -0,0 +1,58 @@ +"""Null-object LLM provider for deployments without any configured API key. + +Wired into the container when LLM_FALLBACK_CHAIN has no entries whose API key is +set. Instant search — which does not use an LLM — keeps working. Classic and +agentic search services are unchanged (they still expect a non-null LLMProtocol); +the failure surfaces on first use as LLMUnavailableError, which the FastAPI +exception handler maps to HTTP 503. +""" + +from __future__ import annotations + +from typing import Any, TypeVar + +from pydantic import BaseModel + +from airweave.adapters.llm.tool_response import LLMResponse +from airweave.core.exceptions import LLMUnavailableError + +T = TypeVar("T", bound=BaseModel) + + +class UnavailableLLM: + """LLMProtocol implementation that raises on every call. + + The protocol is structural (``typing.Protocol``), so no inheritance is + required. Every method and the ``model_spec`` property raise + ``LLMUnavailableError`` with an actionable message. + """ + + @property + def model_spec(self) -> Any: + """Raise because no provider is configured.""" + raise LLMUnavailableError() + + async def structured_output( + self, + prompt: str, + schema: type[T], + system_prompt: str, + thinking: bool = False, + ) -> T: + """Raise because no provider is configured.""" + raise LLMUnavailableError() + + async def chat( + self, + messages: list[dict], + tools: list[dict], + system_prompt: str, + thinking: bool = False, + max_tokens: int | None = None, + ) -> LLMResponse: + """Raise because no provider is configured.""" + raise LLMUnavailableError() + + async def close(self) -> None: + """No-op: the null-object holds no resources.""" + return None diff --git a/backend/airweave/api/middleware.py b/backend/airweave/api/middleware.py index 892c51c85..6237aefa0 100644 --- a/backend/airweave/api/middleware.py +++ b/backend/airweave/api/middleware.py @@ -22,6 +22,7 @@ AirweaveException, InvalidInputError, InvalidStateError, + LLMUnavailableError, NotFoundException, PermissionException, RateLimitExceededException, @@ -437,6 +438,7 @@ async def airweave_exception_handler(request: Request, exc: AirweaveException) - # Add new base classes here as they're introduced (BadRequestError, etc.). status_map = { TokenRefreshError: 401, + LLMUnavailableError: 503, } for exc_type, code in status_map.items(): diff --git a/backend/airweave/core/config/settings.py b/backend/airweave/core/config/settings.py index 1ab327bd0..a5bc6c69e 100644 --- a/backend/airweave/core/config/settings.py +++ b/backend/airweave/core/config/settings.py @@ -195,6 +195,13 @@ class Settings(BaseSettings): TOGETHER_API_KEY: Optional[str] = None AZURE_KEYVAULT_NAME: Optional[str] = None + # Overrides SearchConfig.LLM_FALLBACK_CHAIN when set. + # Format: comma-separated "provider:model" pairs using the values from + # airweave.adapters.llm.registry (e.g. "mistral:mistral-large" or + # "together:zai-glm-5,anthropic:claude-sonnet-4.6"). Unset → use the + # in-code default in domains/search/config.py. + LLM_FALLBACK_CHAIN: Optional[str] = None + # Docling OCR fallback service (None = disabled) DOCLING_BASE_URL: Optional[str] = None diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 2d6602f0a..2799b51a0 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -33,6 +33,7 @@ get_model_spec as get_llm_model_spec, ) from airweave.adapters.llm.together import TogetherLLM +from airweave.adapters.llm.unavailable import UnavailableLLM from airweave.adapters.metrics import ( PrometheusAgenticSearchMetrics, PrometheusDbPoolMetrics, @@ -1186,11 +1187,10 @@ def _build_llm_chain( ): """Build LLM fallback chain from SearchConfig, skipping providers without API keys. - Returns: - An LLM instance (single provider or FallbackChainLLM). - - Raises: - ValueError: If no LLM providers are available. + When no provider in the chain has a configured API key (or all fail to initialize), + returns an ``UnavailableLLM`` null-object rather than raising. This keeps the + backend bootable for deployers who only use Instant search. Classic and agentic + search surface ``LLMUnavailableError`` on first invocation, mapped to HTTP 503. """ provider_classes = { LLMProvider.ANTHROPIC: AnthropicLLM, @@ -1218,10 +1218,11 @@ def _build_llm_chain( available.append((provider, model, model_spec, provider_cls)) if not available: - raise ValueError( - "No LLM providers available for search. " - "Configure at least one API key from SearchConfig.LLM_FALLBACK_CHAIN." + logger.info( + "[SearchFactory] No LLM provider API keys configured — classic/agentic " + "search will return HTTP 503 until a key is set. Instant search is unaffected." ) + return UnavailableLLM() # Single provider: use default retries. # Multiple providers: max_retries=0 for all except the last provider, @@ -1243,9 +1244,11 @@ def _build_llm_chain( ) if not llm_providers: - raise ValueError( - "No LLM providers available for search. All configured providers failed to initialize." + logger.warning( + "[SearchFactory] All configured LLM providers failed to initialize — " + "classic/agentic search will return HTTP 503. Instant search is unaffected." ) + return UnavailableLLM() if len(llm_providers) == 1: return llm_providers[0] diff --git a/backend/airweave/core/container/tests/test_llm_chain_wiring.py b/backend/airweave/core/container/tests/test_llm_chain_wiring.py new file mode 100644 index 000000000..3226e1e3b --- /dev/null +++ b/backend/airweave/core/container/tests/test_llm_chain_wiring.py @@ -0,0 +1,53 @@ +"""Tests for _build_llm_chain: null-object fallback when no providers resolve.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.adapters.llm.unavailable import UnavailableLLM +from airweave.core.container.factory import _build_llm_chain + + +def _settings_with_no_keys() -> MagicMock: + s = MagicMock() + for attr in ( + "TOGETHER_API_KEY", + "ANTHROPIC_API_KEY", + "MISTRAL_API_KEY", + "GROQ_API_KEY", + "CEREBRAS_API_KEY", + ): + setattr(s, attr, None) + return s + + +def _config_with_chain(chain: list[tuple[LLMProvider, LLMModel]]) -> MagicMock: + config = MagicMock() + config.LLM_FALLBACK_CHAIN = chain + return config + + +def test_returns_unavailable_llm_when_no_keys_configured() -> None: + settings = _settings_with_no_keys() + config = _config_with_chain( + [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), + ] + ) + circuit_breaker = MagicMock() + + llm = _build_llm_chain(settings, config, circuit_breaker) + + assert isinstance(llm, UnavailableLLM) + + +def test_returns_unavailable_llm_when_chain_is_empty() -> None: + settings = _settings_with_no_keys() + config = _config_with_chain([]) + circuit_breaker = MagicMock() + + llm = _build_llm_chain(settings, config, circuit_breaker) + + assert isinstance(llm, UnavailableLLM) diff --git a/backend/airweave/core/exceptions.py b/backend/airweave/core/exceptions.py index eea773d88..d00418953 100644 --- a/backend/airweave/core/exceptions.py +++ b/backend/airweave/core/exceptions.py @@ -208,6 +208,28 @@ def __init__( super().__init__(self.message) +class LLMUnavailableError(AirweaveException): + """Raised when an LLM-backed feature is requested but no LLM provider is configured. + + The container wires an UnavailableLLM null-object when LLM_FALLBACK_CHAIN has no + entries with a configured API key. Instant search still works; classic/agentic + search surface this on first use and map to HTTP 503. + """ + + def __init__(self, message: Optional[str] = None): + """Create a new LLMUnavailableError with an actionable default message.""" + if message is None: + message = ( + "No LLM provider configured. Set one of: " + "TOGETHER_API_KEY, ANTHROPIC_API_KEY, MISTRAL_API_KEY, " + "GROQ_API_KEY, CEREBRAS_API_KEY — " + "or customize the chain via LLM_FALLBACK_CHAIN " + "(format: 'provider:model,provider:model')." + ) + self.message = message + super().__init__(self.message) + + class SourceRateLimitExceededException(Exception): """Exception raised when source API rate limit is exceeded. diff --git a/backend/airweave/domains/search/config.py b/backend/airweave/domains/search/config.py index d812862b4..a37ab9480 100644 --- a/backend/airweave/domains/search/config.py +++ b/backend/airweave/domains/search/config.py @@ -4,6 +4,58 @@ from airweave.adapters.llm.registry import LLMModel, LLMProvider from airweave.adapters.tokenizer.registry import TokenizerEncoding, TokenizerType +from airweave.core.config import settings + +_DEFAULT_LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), +] + + +def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMModel]]: + """Parse the LLM_FALLBACK_CHAIN env var. + + Format: comma-separated ``provider:model`` pairs using the enum ``value`` + strings from ``airweave.adapters.llm.registry``. When ``raw`` is None or + empty, returns the in-code default chain. + + Raises ValueError at import time (startup) on unknown provider or model names, + listing the accepted values so deployers can fix the misconfiguration fast. + """ + if not raw or not raw.strip(): + return list(_DEFAULT_LLM_FALLBACK_CHAIN) + + valid_providers = {p.value: p for p in LLMProvider} + valid_models = {m.value: m for m in LLMModel} + + parsed: list[tuple[LLMProvider, LLMModel]] = [] + for entry in raw.split(","): + entry = entry.strip() + if not entry: + continue + if ":" not in entry: + raise ValueError( + f"Invalid LLM_FALLBACK_CHAIN entry {entry!r}: expected 'provider:model'." + ) + provider_raw, model_raw = entry.split(":", 1) + provider_raw = provider_raw.strip() + model_raw = model_raw.strip() + + if provider_raw not in valid_providers: + raise ValueError( + f"Unknown provider {provider_raw!r} in LLM_FALLBACK_CHAIN. " + f"Accepted: {sorted(valid_providers)}." + ) + if model_raw not in valid_models: + raise ValueError( + f"Unknown model {model_raw!r} in LLM_FALLBACK_CHAIN. " + f"Accepted: {sorted(valid_models)}." + ) + parsed.append((valid_providers[provider_raw], valid_models[model_raw])) + + if not parsed: + return list(_DEFAULT_LLM_FALLBACK_CHAIN) + return parsed class DatabaseImpl(str, Enum): @@ -35,13 +87,11 @@ class SearchConfig: # configured) and responds successfully handles the request. Subsequent # providers are only tried when the previous one fails. # - # To change the primary model, reorder this list or swap the model for a - # provider. For example, to use GPT_OSS_120B on Cerebras instead of GLM: - # (LLMProvider.CEREBRAS, LLMModel.GPT_OSS_120B), - LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = [ - (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), - (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), - ] + # Deployers can override via the LLM_FALLBACK_CHAIN env var + # (format: "provider:model,provider:model"). Unset → use the default below. + LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = parse_llm_fallback_chain( + settings.LLM_FALLBACK_CHAIN + ) # Tokenizer # Note: Must be compatible with the chosen LLM model (validated at startup) diff --git a/backend/airweave/domains/search/tests/test_config.py b/backend/airweave/domains/search/tests/test_config.py new file mode 100644 index 000000000..e7a91e6af --- /dev/null +++ b/backend/airweave/domains/search/tests/test_config.py @@ -0,0 +1,75 @@ +"""Tests for LLM_FALLBACK_CHAIN env-var parser in search config.""" + +from __future__ import annotations + +import pytest + +from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.domains.search.config import ( + _DEFAULT_LLM_FALLBACK_CHAIN, + parse_llm_fallback_chain, +) + + +def test_none_returns_default_chain() -> None: + assert parse_llm_fallback_chain(None) == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_empty_string_returns_default_chain() -> None: + assert parse_llm_fallback_chain("") == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_whitespace_only_returns_default_chain() -> None: + assert parse_llm_fallback_chain(" ") == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_single_entry_parsed_to_tuple() -> None: + parsed = parse_llm_fallback_chain("mistral:mistral-large") + assert parsed == [(LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE)] + + +def test_multiple_entries_preserve_order() -> None: + parsed = parse_llm_fallback_chain( + "together:zai-glm-5,anthropic:claude-sonnet-4.6,mistral:mistral-large" + ) + assert parsed == [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), + (LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE), + ] + + +def test_whitespace_around_entries_is_ignored() -> None: + parsed = parse_llm_fallback_chain(" mistral : mistral-large , anthropic : claude-sonnet-4.6 ") + assert parsed == [ + (LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), + ] + + +def test_unknown_provider_raises_with_accepted_list() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("bogus:mistral-large") + message = str(excinfo.value) + assert "bogus" in message + assert "mistral" in message + assert "anthropic" in message + + +def test_unknown_model_raises_with_accepted_list() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("mistral:not-a-real-model") + message = str(excinfo.value) + assert "not-a-real-model" in message + assert "mistral-large" in message + + +def test_missing_colon_raises_helpful_error() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("mistral-only") + assert "provider:model" in str(excinfo.value) + + +def test_trailing_comma_is_tolerated() -> None: + parsed = parse_llm_fallback_chain("mistral:mistral-large,") + assert parsed == [(LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE)] From 89c6314a0d84567a7767976debdac19b1486ffe2 Mon Sep 17 00:00:00 2001 From: Daan Manneke Date: Thu, 23 Apr 2026 10:09:48 +0200 Subject: [PATCH 12/25] refactor(sharepoint_online): remove dead _make_sp_token_provider methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Felix flagged the legacy single-site SP token provider as having no callers. Confirmed via grep — only the base declaration and two subclass overrides existed, nothing calling them. Removed all three, plus the likewise-unused _derive_sp_resource_scope helper on the OAuth subclass and the stale base-class docstring reference. _make_sp_token_provider_for_site(site_url) is the sole entry point. --- .../sources/sharepoint_online/source.py | 31 +------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index c5d7eb4aa..c4cb4b999 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -97,7 +97,7 @@ class SharePointOnlineBase(BaseSource): - create() — class constructor - _get_access_token() — return a valid Microsoft Graph token - _handle_401() — refresh/re-exchange on 401, return new token - - _make_sp_token_provider() — callable returning SP REST API token + - _make_sp_token_provider_for_site(site_url) — per-site SP REST token provider - _get_download_auth(url) — auth suitable for file download - _discover_sites(graph_client) — site discovery strategy """ @@ -129,14 +129,6 @@ async def _handle_401(self) -> str: """Handle a 401 by refreshing/re-exchanging. Returns new token.""" raise NotImplementedError - def _make_sp_token_provider(self) -> Optional[Callable]: - """Create an async callable returning a SharePoint REST API token, or None. - - Legacy single-site provider that derives hostname from ``self._site_url``. - New code should use ``_make_sp_token_provider_for_site`` instead. - """ - raise NotImplementedError - def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: """Create an SP REST token provider scoped to a specific site URL. @@ -1351,10 +1343,6 @@ async def _handle_401(self) -> str: return await self.auth.force_refresh() return await self.auth.get_token() - def _make_sp_token_provider(self) -> Optional[Callable]: - """Create SP token provider via OAuth scope exchange (config-site-bound).""" - return self._make_sp_token_provider_for_site(self._site_url) - def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: """Create SP token provider for a specific site URL via OAuth scope exchange.""" if not site_url: @@ -1373,19 +1361,6 @@ async def _provider() -> str: return _provider - def _derive_sp_resource_scope(self) -> Optional[str]: - """Derive the SharePoint resource scope from the site URL. - - E.g. https://neenacorp.sharepoint.com/sites/JAman - -> https://neenacorp.sharepoint.com/.default - """ - if not self._site_url: - return None - parsed = urlparse(self._site_url) - if not parsed.netloc: - return None - return f"https://{parsed.netloc}/.default" - async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: """Discover sites via Graph search (delegated permissions). @@ -1596,10 +1571,6 @@ async def _handle_401(self) -> str: self._graph_token_expires = 0 # force re-exchange return await self._get_access_token() - def _make_sp_token_provider(self) -> Optional[Callable]: - """Create SP token provider via certificate exchange (config-site-bound).""" - return self._make_sp_token_provider_for_site(self._site_url) - def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: """Create SP token provider for a specific site URL via certificate exchange.""" if not site_url: From 949f058cee02e1bd8598329e514a4b6118370d4a Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 10:19:21 +0200 Subject: [PATCH 13/25] refactor: tighten LLM fallback chain wiring from review - Derive env-var list in LLMUnavailableError from PROVIDER_API_KEY_SETTINGS so the default message stays in sync when providers are added. - Extract _unavailable() helper in _build_llm_chain to dedupe the two UnavailableLLM return paths and their log lines. - Narrow UnavailableLLM.model_spec return type from Any to LLMModelSpec. - Hoist provider/model lookup dicts to module scope in search config, surface accepted values in enum-declaration order in error messages, and note that SearchConfig.LLM_FALLBACK_CHAIN is evaluated once at class-definition time. --- backend/airweave/adapters/llm/unavailable.py | 5 +++-- backend/airweave/core/container/factory.py | 19 +++++++++--------- backend/airweave/core/exceptions.py | 7 ++++--- backend/airweave/domains/search/config.py | 21 ++++++++++++-------- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/backend/airweave/adapters/llm/unavailable.py b/backend/airweave/adapters/llm/unavailable.py index 1065706de..cca436d00 100644 --- a/backend/airweave/adapters/llm/unavailable.py +++ b/backend/airweave/adapters/llm/unavailable.py @@ -9,10 +9,11 @@ from __future__ import annotations -from typing import Any, TypeVar +from typing import TypeVar from pydantic import BaseModel +from airweave.adapters.llm.registry import LLMModelSpec from airweave.adapters.llm.tool_response import LLMResponse from airweave.core.exceptions import LLMUnavailableError @@ -28,7 +29,7 @@ class UnavailableLLM: """ @property - def model_spec(self) -> Any: + def model_spec(self) -> LLMModelSpec: """Raise because no provider is configured.""" raise LLMUnavailableError() diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 2799b51a0..84cfc320e 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -1200,6 +1200,13 @@ def _build_llm_chain( LLMProvider.TOGETHER: TogetherLLM, } + def _unavailable(reason: str, level: str = "info") -> UnavailableLLM: + getattr(logger, level)( + f"[SearchFactory] {reason} — classic/agentic search will return HTTP 503 " + "until a key is set. Instant search is unaffected." + ) + return UnavailableLLM() + # Collect available (provider, model_spec, class) tuples first, # then decide retry strategy based on how many survived. available = [] @@ -1218,11 +1225,7 @@ def _build_llm_chain( available.append((provider, model, model_spec, provider_cls)) if not available: - logger.info( - "[SearchFactory] No LLM provider API keys configured — classic/agentic " - "search will return HTTP 503 until a key is set. Instant search is unaffected." - ) - return UnavailableLLM() + return _unavailable("No LLM provider API keys configured") # Single provider: use default retries. # Multiple providers: max_retries=0 for all except the last provider, @@ -1244,11 +1247,7 @@ def _build_llm_chain( ) if not llm_providers: - logger.warning( - "[SearchFactory] All configured LLM providers failed to initialize — " - "classic/agentic search will return HTTP 503. Instant search is unaffected." - ) - return UnavailableLLM() + return _unavailable("All configured LLM providers failed to initialize", level="warning") if len(llm_providers) == 1: return llm_providers[0] diff --git a/backend/airweave/core/exceptions.py b/backend/airweave/core/exceptions.py index d00418953..83110ae45 100644 --- a/backend/airweave/core/exceptions.py +++ b/backend/airweave/core/exceptions.py @@ -4,6 +4,8 @@ from pydantic import ValidationError +from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS + class AirweaveException(Exception): """Base exception for Airweave services.""" @@ -219,10 +221,9 @@ class LLMUnavailableError(AirweaveException): def __init__(self, message: Optional[str] = None): """Create a new LLMUnavailableError with an actionable default message.""" if message is None: + env_vars = ", ".join(PROVIDER_API_KEY_SETTINGS.values()) message = ( - "No LLM provider configured. Set one of: " - "TOGETHER_API_KEY, ANTHROPIC_API_KEY, MISTRAL_API_KEY, " - "GROQ_API_KEY, CEREBRAS_API_KEY — " + f"No LLM provider configured. Set one of: {env_vars} — " "or customize the chain via LLM_FALLBACK_CHAIN " "(format: 'provider:model,provider:model')." ) diff --git a/backend/airweave/domains/search/config.py b/backend/airweave/domains/search/config.py index a37ab9480..a096a8f63 100644 --- a/backend/airweave/domains/search/config.py +++ b/backend/airweave/domains/search/config.py @@ -11,6 +11,11 @@ (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), ] +# Value → enum lookup tables built once at import time. Dict insertion order +# matches enum declaration order, which we surface in error messages. +_VALID_PROVIDERS: dict[str, LLMProvider] = {p.value: p for p in LLMProvider} +_VALID_MODELS: dict[str, LLMModel] = {m.value: m for m in LLMModel} + def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMModel]]: """Parse the LLM_FALLBACK_CHAIN env var. @@ -25,9 +30,6 @@ def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMMode if not raw or not raw.strip(): return list(_DEFAULT_LLM_FALLBACK_CHAIN) - valid_providers = {p.value: p for p in LLMProvider} - valid_models = {m.value: m for m in LLMModel} - parsed: list[tuple[LLMProvider, LLMModel]] = [] for entry in raw.split(","): entry = entry.strip() @@ -41,17 +43,17 @@ def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMMode provider_raw = provider_raw.strip() model_raw = model_raw.strip() - if provider_raw not in valid_providers: + if provider_raw not in _VALID_PROVIDERS: raise ValueError( f"Unknown provider {provider_raw!r} in LLM_FALLBACK_CHAIN. " - f"Accepted: {sorted(valid_providers)}." + f"Accepted: {list(_VALID_PROVIDERS)}." ) - if model_raw not in valid_models: + if model_raw not in _VALID_MODELS: raise ValueError( f"Unknown model {model_raw!r} in LLM_FALLBACK_CHAIN. " - f"Accepted: {sorted(valid_models)}." + f"Accepted: {list(_VALID_MODELS)}." ) - parsed.append((valid_providers[provider_raw], valid_models[model_raw])) + parsed.append((_VALID_PROVIDERS[provider_raw], _VALID_MODELS[model_raw])) if not parsed: return list(_DEFAULT_LLM_FALLBACK_CHAIN) @@ -89,6 +91,9 @@ class SearchConfig: # # Deployers can override via the LLM_FALLBACK_CHAIN env var # (format: "provider:model,provider:model"). Unset → use the default below. + # Evaluated once at class-definition time. Tests that need to vary this must + # call parse_llm_fallback_chain directly or reload the module — monkey- + # patching settings.LLM_FALLBACK_CHAIN after import has no effect here. LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = parse_llm_fallback_chain( settings.LLM_FALLBACK_CHAIN ) From 8c4bdd9f592ce4ad85de4aca681e243b9a9457d1 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 10:22:39 +0200 Subject: [PATCH 14/25] fix(tests): drop redundant assertion on UnavailableLLM.close() close() is typed -> None, so 'assert await llm.close() is None' trips mypy's func-returns-value check. The test's intent is that close() is a safe no-op (does not raise), which a bare await already covers. --- backend/airweave/adapters/llm/tests/test_unavailable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/airweave/adapters/llm/tests/test_unavailable.py b/backend/airweave/adapters/llm/tests/test_unavailable.py index 0b9c76deb..8da5f316b 100644 --- a/backend/airweave/adapters/llm/tests/test_unavailable.py +++ b/backend/airweave/adapters/llm/tests/test_unavailable.py @@ -36,7 +36,7 @@ def test_model_spec_raises_llm_unavailable_error() -> None: @pytest.mark.asyncio async def test_close_is_a_safe_noop() -> None: llm = UnavailableLLM() - assert await llm.close() is None + await llm.close() # must not raise; returns None by type def test_error_message_mentions_accepted_api_key_env_vars() -> None: From e8567094ee3d6d622eaf6bfe971c5be5af8baf14 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 10:35:10 +0200 Subject: [PATCH 15/25] fix(deps): drop squatted 'weaviate' placeholder package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pyproject.toml declared both weaviate (^0.1.2) and weaviate-client (^4.10.2). The former is a PyPI placeholder ("A placeholder package for the Weaviate name") that ships its own weaviate/__init__.py, colliding with the real client. Poetry 2.x now fails the Docker image build with 'Installing weaviate/__init__.py over existing file'; 1.x silently tolerated the overwrite. No code imports weaviate — the only repo reference is a URL string in a test. Removing the placeholder unbreaks the container build. --- backend/poetry.lock | 145 +---------------------------------------- backend/pyproject.toml | 2 - 2 files changed, 1 insertion(+), 146 deletions(-) diff --git a/backend/poetry.lock b/backend/poetry.lock index edfff3ff3..7dfbf176d 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1625,21 +1625,6 @@ wrapt = ">=1.10,<3" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools ; python_version >= \"3.12\"", "tox"] -[[package]] -name = "deprecation" -version = "2.1.0" -description = "A library to handle automated deprecations" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, - {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, -] - -[package.dependencies] -packaging = "*" - [[package]] name = "diff-cover" version = "10.2.0" @@ -2677,83 +2662,6 @@ typing-extensions = ">=4.10,<5" [package.extras] aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.8)"] -[[package]] -name = "grpcio" -version = "1.76.0" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "grpcio-1.76.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc"}, - {file = "grpcio-1.76.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3"}, - {file = "grpcio-1.76.0-cp310-cp310-win32.whl", hash = "sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b"}, - {file = "grpcio-1.76.0-cp310-cp310-win_amd64.whl", hash = "sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b"}, - {file = "grpcio-1.76.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a"}, - {file = "grpcio-1.76.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00"}, - {file = "grpcio-1.76.0-cp311-cp311-win32.whl", hash = "sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054"}, - {file = "grpcio-1.76.0-cp311-cp311-win_amd64.whl", hash = "sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d"}, - {file = "grpcio-1.76.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8"}, - {file = "grpcio-1.76.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882"}, - {file = "grpcio-1.76.0-cp312-cp312-win32.whl", hash = "sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958"}, - {file = "grpcio-1.76.0-cp312-cp312-win_amd64.whl", hash = "sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347"}, - {file = "grpcio-1.76.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2"}, - {file = "grpcio-1.76.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42"}, - {file = "grpcio-1.76.0-cp313-cp313-win32.whl", hash = "sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f"}, - {file = "grpcio-1.76.0-cp313-cp313-win_amd64.whl", hash = "sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8"}, - {file = "grpcio-1.76.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62"}, - {file = "grpcio-1.76.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc"}, - {file = "grpcio-1.76.0-cp314-cp314-win32.whl", hash = "sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e"}, - {file = "grpcio-1.76.0-cp314-cp314-win_amd64.whl", hash = "sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e"}, - {file = "grpcio-1.76.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:8ebe63ee5f8fa4296b1b8cfc743f870d10e902ca18afc65c68cf46fd39bb0783"}, - {file = "grpcio-1.76.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:3bf0f392c0b806905ed174dcd8bdd5e418a40d5567a05615a030a5aeddea692d"}, - {file = "grpcio-1.76.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b7604868b38c1bfd5cf72d768aedd7db41d78cb6a4a18585e33fb0f9f2363fd"}, - {file = "grpcio-1.76.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:e6d1db20594d9daba22f90da738b1a0441a7427552cc6e2e3d1297aeddc00378"}, - {file = "grpcio-1.76.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d099566accf23d21037f18a2a63d323075bebace807742e4b0ac210971d4dd70"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ebea5cc3aa8ea72e04df9913492f9a96d9348db876f9dda3ad729cfedf7ac416"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0c37db8606c258e2ee0c56b78c62fc9dee0e901b5dbdcf816c2dd4ad652b8b0c"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ebebf83299b0cb1721a8859ea98f3a77811e35dce7609c5c963b9ad90728f886"}, - {file = "grpcio-1.76.0-cp39-cp39-win32.whl", hash = "sha256:0aaa82d0813fd4c8e589fac9b65d7dd88702555f702fb10417f96e2a2a6d4c0f"}, - {file = "grpcio-1.76.0-cp39-cp39-win_amd64.whl", hash = "sha256:acab0277c40eff7143c2323190ea57b9ee5fd353d8190ee9652369fae735668a"}, - {file = "grpcio-1.76.0.tar.gz", hash = "sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73"}, -] - -[package.dependencies] -typing-extensions = ">=4.12,<5.0" - -[package.extras] -protobuf = ["grpcio-tools (>=1.76.0)"] - [[package]] name = "h11" version = "0.16.0" @@ -8013,21 +7921,6 @@ dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"] docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx_rtd_theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["aiohttp (>=3.10.5)", "flake8 (>=6.1,<7.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=25.3.0,<25.4.0)", "pycodestyle (>=2.11.0,<2.12.0)"] -[[package]] -name = "validators" -version = "0.35.0" -description = "Python Data Validation for Humans™" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd"}, - {file = "validators-0.35.0.tar.gz", hash = "sha256:992d6c48a4e77c81f1b4daba10d16c3a9bb0dbb79b3a19ea847ff0928e70497a"}, -] - -[package.extras] -crypto-eth-addresses = ["eth-hash[pycryptodome] (>=0.7.0)"] - [[package]] name = "virtualenv" version = "20.36.1" @@ -8183,42 +8076,6 @@ files = [ [package.dependencies] anyio = ">=3.0.0" -[[package]] -name = "weaviate" -version = "0.1.2" -description = "A placeholder package for the Weaviate name" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "weaviate-0.1.2-py3-none-any.whl", hash = "sha256:40f1c1cf0b769036315d2b6026c8cd823a3a6e951c90d4e70a001a770ba8a444"}, - {file = "weaviate-0.1.2.tar.gz", hash = "sha256:a381b8bb0eb236bd10256def8612953ed9024e6738b8a259e7ec11e626ae0665"}, -] - -[[package]] -name = "weaviate-client" -version = "4.19.2" -description = "A python native Weaviate client" -optional = false -python-versions = ">=3.10" -groups = ["main"] -files = [ - {file = "weaviate_client-4.19.2-py3-none-any.whl", hash = "sha256:e78306d47c574c4035c87223e480bb77bd6e54142a21c4c58522dd43019fe493"}, - {file = "weaviate_client-4.19.2.tar.gz", hash = "sha256:99e76e912c95762436089cd5feedbfeea31e892aa13b6ad94729a2a54b316c45"}, -] - -[package.dependencies] -authlib = ">=1.6.5,<2.0.0" -deprecation = ">=2.1.0,<3.0.0" -grpcio = ">=1.59.5,<1.80.0" -httpx = ">=0.26.0,<0.29.0" -protobuf = ">=4.21.6,<7.0.0" -pydantic = ">=2.12.0,<3.0.0" -validators = ">=0.34.0,<1.0.0" - -[package.extras] -agents = ["weaviate-agents (>=1.0.0,<2.0.0)"] - [[package]] name = "websockets" version = "16.0" @@ -8610,4 +8467,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.13,<3.14" -content-hash = "2065fc338a5f211587a8838f8c3dcee58987ec94e5613452be47ac8dec37e6d0" +content-hash = "60288465ee2e36d9dd0099a018a0949f08dea9e3f96abad034a66d7976317d0c" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index c092253a3..f452f7896 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -28,8 +28,6 @@ tenacity = "^8.2.3" structlog = "^24.1.0" pydantic-settings = "^2.7.0" psycopg2-binary = "^2.9.10" -weaviate = "^0.1.2" -weaviate-client = "^4.10.2" markitdown = "^0.0.1a3" neo4j = "^5.27.0" pyodbc = "^5.2.0" From f6511ade08fa1457cf0d39b978cc01162b53f787 Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 10:50:57 +0200 Subject: [PATCH 16/25] fix: harden LLM fallback chain wiring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - parse_llm_fallback_chain now validates (provider, model) pairs against MODEL_REGISTRY, so combos like together:mistral-large fail at startup with an actionable error instead of crashing later in the factory. - Move the detailed env-var message out of core.exceptions (which had to import from adapters.llm.registry) into UnavailableLLM itself. The exception keeps a generic default pointing at the docs. - Drop the unreachable empty-chain guard in _create_search_services — the parser always returns a non-empty list. - Test the env-var-name assertion dynamically against PROVIDER_API_KEY_SETTINGS so it stays in sync with the registry. - Add a 'Configuring the LLM provider chain' section to the Search docs covering API keys, LLM_FALLBACK_CHAIN format, and fallback semantics, so the exception's doc reference has a real destination. --- .../adapters/llm/tests/test_unavailable.py | 11 ++-- backend/airweave/adapters/llm/unavailable.py | 15 ++++-- backend/airweave/core/container/factory.py | 3 -- backend/airweave/core/exceptions.py | 16 ++---- backend/airweave/domains/search/config.py | 12 ++++- .../domains/search/tests/test_config.py | 9 ++++ fern/docs/pages/search.mdx | 53 +++++++++++++++++++ 7 files changed, 91 insertions(+), 28 deletions(-) diff --git a/backend/airweave/adapters/llm/tests/test_unavailable.py b/backend/airweave/adapters/llm/tests/test_unavailable.py index 8da5f316b..5eae868b9 100644 --- a/backend/airweave/adapters/llm/tests/test_unavailable.py +++ b/backend/airweave/adapters/llm/tests/test_unavailable.py @@ -5,6 +5,7 @@ import pytest from pydantic import BaseModel +from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS from airweave.adapters.llm.unavailable import UnavailableLLM from airweave.core.exceptions import LLMUnavailableError @@ -45,12 +46,6 @@ def test_error_message_mentions_accepted_api_key_env_vars() -> None: _ = llm.model_spec message = str(excinfo.value) - for env_var in ( - "TOGETHER_API_KEY", - "ANTHROPIC_API_KEY", - "MISTRAL_API_KEY", - "GROQ_API_KEY", - "CEREBRAS_API_KEY", - "LLM_FALLBACK_CHAIN", - ): + for env_var in PROVIDER_API_KEY_SETTINGS.values(): assert env_var in message, f"{env_var} missing from error message" + assert "LLM_FALLBACK_CHAIN" in message diff --git a/backend/airweave/adapters/llm/unavailable.py b/backend/airweave/adapters/llm/unavailable.py index cca436d00..8d4c76a81 100644 --- a/backend/airweave/adapters/llm/unavailable.py +++ b/backend/airweave/adapters/llm/unavailable.py @@ -13,12 +13,19 @@ from pydantic import BaseModel -from airweave.adapters.llm.registry import LLMModelSpec +from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS, LLMModelSpec from airweave.adapters.llm.tool_response import LLMResponse from airweave.core.exceptions import LLMUnavailableError T = TypeVar("T", bound=BaseModel) +_DETAILED_MESSAGE = ( + "No LLM provider configured. Set one of: " + f"{', '.join(PROVIDER_API_KEY_SETTINGS.values())} — " + "or customize the chain via LLM_FALLBACK_CHAIN " + "(format: 'provider:model,provider:model')." +) + class UnavailableLLM: """LLMProtocol implementation that raises on every call. @@ -31,7 +38,7 @@ class UnavailableLLM: @property def model_spec(self) -> LLMModelSpec: """Raise because no provider is configured.""" - raise LLMUnavailableError() + raise LLMUnavailableError(_DETAILED_MESSAGE) async def structured_output( self, @@ -41,7 +48,7 @@ async def structured_output( thinking: bool = False, ) -> T: """Raise because no provider is configured.""" - raise LLMUnavailableError() + raise LLMUnavailableError(_DETAILED_MESSAGE) async def chat( self, @@ -52,7 +59,7 @@ async def chat( max_tokens: int | None = None, ) -> LLMResponse: """Raise because no provider is configured.""" - raise LLMUnavailableError() + raise LLMUnavailableError(_DETAILED_MESSAGE) async def close(self) -> None: """No-op: the null-object holds no resources.""" diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 84cfc320e..a18a64669 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -1279,9 +1279,6 @@ def _create_search_services( config = SearchConfig() # 1. Tokenizer — validate against primary LLM model requirements - if not config.LLM_FALLBACK_CHAIN: - raise ValueError("LLM_FALLBACK_CHAIN is empty — at least one provider is required") - primary_provider, primary_model = config.LLM_FALLBACK_CHAIN[0] primary_llm_spec = get_llm_model_spec(primary_provider, primary_model) diff --git a/backend/airweave/core/exceptions.py b/backend/airweave/core/exceptions.py index 83110ae45..db8f69c65 100644 --- a/backend/airweave/core/exceptions.py +++ b/backend/airweave/core/exceptions.py @@ -4,8 +4,6 @@ from pydantic import ValidationError -from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS - class AirweaveException(Exception): """Base exception for Airweave services.""" @@ -218,15 +216,11 @@ class LLMUnavailableError(AirweaveException): search surface this on first use and map to HTTP 503. """ - def __init__(self, message: Optional[str] = None): - """Create a new LLMUnavailableError with an actionable default message.""" - if message is None: - env_vars = ", ".join(PROVIDER_API_KEY_SETTINGS.values()) - message = ( - f"No LLM provider configured. Set one of: {env_vars} — " - "or customize the chain via LLM_FALLBACK_CHAIN " - "(format: 'provider:model,provider:model')." - ) + def __init__( + self, + message: str = ("No LLM provider configured. See LLM_FALLBACK_CHAIN docs for setup."), + ): + """Create a new LLMUnavailableError with an actionable message.""" self.message = message super().__init__(self.message) diff --git a/backend/airweave/domains/search/config.py b/backend/airweave/domains/search/config.py index a096a8f63..224e2dad1 100644 --- a/backend/airweave/domains/search/config.py +++ b/backend/airweave/domains/search/config.py @@ -2,7 +2,7 @@ from enum import Enum -from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.adapters.llm.registry import MODEL_REGISTRY, LLMModel, LLMProvider from airweave.adapters.tokenizer.registry import TokenizerEncoding, TokenizerType from airweave.core.config import settings @@ -53,7 +53,15 @@ def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMMode f"Unknown model {model_raw!r} in LLM_FALLBACK_CHAIN. " f"Accepted: {list(_VALID_MODELS)}." ) - parsed.append((_VALID_PROVIDERS[provider_raw], _VALID_MODELS[model_raw])) + provider = _VALID_PROVIDERS[provider_raw] + model = _VALID_MODELS[model_raw] + provider_models = MODEL_REGISTRY.get(provider, {}) + if model not in provider_models: + raise ValueError( + f"Model {model_raw!r} not available for provider {provider_raw!r}. " + f"Available: {[m.value for m in provider_models]}." + ) + parsed.append((provider, model)) if not parsed: return list(_DEFAULT_LLM_FALLBACK_CHAIN) diff --git a/backend/airweave/domains/search/tests/test_config.py b/backend/airweave/domains/search/tests/test_config.py index e7a91e6af..f1c267cba 100644 --- a/backend/airweave/domains/search/tests/test_config.py +++ b/backend/airweave/domains/search/tests/test_config.py @@ -73,3 +73,12 @@ def test_missing_colon_raises_helpful_error() -> None: def test_trailing_comma_is_tolerated() -> None: parsed = parse_llm_fallback_chain("mistral:mistral-large,") assert parsed == [(LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE)] + + +def test_valid_enums_but_invalid_pair_raises() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("together:mistral-large") + message = str(excinfo.value) + assert "mistral-large" in message + assert "together" in message + assert "zai-glm-5" in message diff --git a/fern/docs/pages/search.mdx b/fern/docs/pages/search.mdx index 114f7736a..0432c0b4c 100644 --- a/fern/docs/pages/search.mdx +++ b/fern/docs/pages/search.mdx @@ -581,3 +581,56 @@ This allows expressions like: `(A AND B) OR (C AND D)` ## Response Format All three tiers return the same `SearchV2Response` with a `results` array. See the [API Reference](/api-reference/collections/instant-search) for the full response schema and interactive examples. + + +## Configuring the LLM provider chain + +Classic and Agentic search call an LLM. Instant search does not — a backend with no LLM configured still answers instant queries, and Classic/Agentic return HTTP 503 until an API key is set. This section is only relevant to self-hosted deployments; the managed service ships with providers configured. + +### Default chain + +Out of the box Airweave tries providers in this order: + +1. `together:zai-glm-5` +2. `anthropic:claude-sonnet-4.6` + +The first provider whose API key is set and that responds successfully handles the request. Subsequent entries are only tried on failure. + +### Setting API keys + +Set at least one of the following environment variables on the backend: + +| Env var | Provider | +|---|---| +| `TOGETHER_API_KEY` | Together | +| `ANTHROPIC_API_KEY` | Anthropic | +| `MISTRAL_API_KEY` | Mistral | +| `GROQ_API_KEY` | Groq | +| `CEREBRAS_API_KEY` | Cerebras | + +If none are set, the backend boots normally; Classic/Agentic search return `503 Service Unavailable` with a message listing these variables. + +### Overriding the chain + +Set `LLM_FALLBACK_CHAIN` to a comma-separated list of `provider:model` pairs. Example: + +``` +LLM_FALLBACK_CHAIN=cerebras:gpt-oss-120b,anthropic:claude-sonnet-4.6 +``` + +Supported providers: `cerebras`, `groq`, `anthropic`, `together`, `mistral`. The full list of models per provider lives in `backend/airweave/adapters/llm/registry.py`. + +The parser validates three things at startup: + +- Every provider is a known provider. +- Every model is a known model. +- Every `(provider, model)` combination exists in the registry (e.g. `together:mistral-large` is rejected because `mistral-large` is hosted on Mistral, not Together). + +Misconfiguration fails fast with an error that lists the accepted values. + +### Fallback semantics + +- Providers without an API key are silently skipped when the chain is built. +- Providers whose initialization raises are logged and skipped. +- If the resulting chain is empty, the backend wires a null LLM — instant search still works; classic/agentic return 503. +- When a call fails in a chained provider, the next one is tried; a circuit breaker temporarily removes providers that recently failed. From ada1033a0963d6ea295a1b3ec94bf17fefa1624a Mon Sep 17 00:00:00 2001 From: Felix Schmetz Date: Thu, 23 Apr 2026 11:04:51 +0200 Subject: [PATCH 17/25] docs(search): address PR review on LLM provider chain section Apply @hiddeco's suggestions: - Lift 'self-hosted only' caveat into a Fern Callout. - Tighten 'Out of the box, Airweave...' phrasing. - Rework 'first provider with an API key set that responds...' for clarity. - Reword 'Misconfiguration is caught at startup...'. - Capitalize 'Classic/Agentic' in the fallback-semantics bullet to match the rest of the doc. --- fern/docs/pages/search.mdx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/fern/docs/pages/search.mdx b/fern/docs/pages/search.mdx index 0432c0b4c..d0d343e59 100644 --- a/fern/docs/pages/search.mdx +++ b/fern/docs/pages/search.mdx @@ -585,16 +585,20 @@ All three tiers return the same `SearchV2Response` with a `results` array. See t ## Configuring the LLM provider chain -Classic and Agentic search call an LLM. Instant search does not — a backend with no LLM configured still answers instant queries, and Classic/Agentic return HTTP 503 until an API key is set. This section is only relevant to self-hosted deployments; the managed service ships with providers configured. + + This section is only relevant to self-hosted deployments. The managed service ships with providers configured. + + +Classic and Agentic search call an LLM. Instant search does not — a backend with no LLM configured still answers instant queries, and Classic/Agentic return HTTP 503 until an API key is set. ### Default chain -Out of the box Airweave tries providers in this order: +Out of the box, Airweave tries providers in this order: 1. `together:zai-glm-5` 2. `anthropic:claude-sonnet-4.6` -The first provider whose API key is set and that responds successfully handles the request. Subsequent entries are only tried on failure. +The first provider with an API key set that responds successfully handles the request. Subsequent entries are tried only on failure. ### Setting API keys @@ -626,11 +630,11 @@ The parser validates three things at startup: - Every model is a known model. - Every `(provider, model)` combination exists in the registry (e.g. `together:mistral-large` is rejected because `mistral-large` is hosted on Mistral, not Together). -Misconfiguration fails fast with an error that lists the accepted values. +Misconfiguration is caught at startup with an error that lists the accepted values. ### Fallback semantics - Providers without an API key are silently skipped when the chain is built. - Providers whose initialization raises are logged and skipped. -- If the resulting chain is empty, the backend wires a null LLM — instant search still works; classic/agentic return 503. +- If the resulting chain is empty, the backend wires a null LLM — instant search still works; Classic/Agentic return 503. - When a call fails in a chained provider, the next one is tried; a circuit breaker temporarily removes providers that recently failed. From 99ca0ed357c085fc760884cf8c99e00f08767eba Mon Sep 17 00:00:00 2001 From: Daan Manneke Date: Tue, 5 May 2026 10:58:45 +0200 Subject: [PATCH 18/25] fix(sharepoint_online): stop treating org-scoped sharing links as public MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit An "anyone in your organization with the link" share is a capability, not an ACL: a user must possess the link URL to access the file. Previously, ``extract_access_control`` mapped ``link.scope == "organization"`` to ``AccessControl.is_public = True``, causing the search broker to skip all viewer/group checks for any file with such a link. Net effect of the bug: any file that ever had an org-scoped share link created became visible in search to every authenticated user in the tenant — including files on private sites with restricted folder permissions, where only specific users / groups should have access. Reported by Mistral on a private site whose files appeared in search results for unrelated tenant users. Fix: only ``link.scope == "anonymous"`` flips ``is_public``. Org-scoped link permissions are now skipped — their audience is already represented via the ``SharingLinks...`` SP site group that SharePoint creates per link, which the connector tracks separately and resolves through ACL membership at search time. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../platform/sources/sharepoint_online/acl.py | 16 +- .../sources/test_sharepoint_online_acl.py | 193 ++++++++++++++++++ 2 files changed, 199 insertions(+), 10 deletions(-) create mode 100644 backend/tests/unit/platform/sources/test_sharepoint_online_acl.py diff --git a/backend/airweave/platform/sources/sharepoint_online/acl.py b/backend/airweave/platform/sources/sharepoint_online/acl.py index 404cc0d6a..8a50b4b5e 100644 --- a/backend/airweave/platform/sources/sharepoint_online/acl.py +++ b/backend/airweave/platform/sources/sharepoint_online/acl.py @@ -6,7 +6,11 @@ - grantedToV2.user → user:{email} - grantedToV2.group (Entra ID) → group:entra:{group_id} - grantedToV2.siteGroup → group:sp:{site_group_name} -- link with scope "organization" → is_public (org-wide access) +- link with scope "anonymous" → is_public (true tenant-wide / internet-wide access) + +Organization-scoped sharing links ("anyone in your org with the link") are NOT +treated as public. Possession of the link URL is required, so the audience is +captured via the per-link SharingLinks.* SP site group rather than is_public. """ from typing import Any, Dict, List, Optional @@ -76,14 +80,6 @@ def has_read_permission(permission: Dict[str, Any]) -> bool: return any(r in ("read", "write", "owner", "sp.full control") for r in roles) -def is_org_wide_link(permission: Dict[str, Any]) -> bool: - """Check if a permission is an organization-wide sharing link.""" - link = permission.get("link") - if not link: - return False - return link.get("scope", "") == "organization" - - def is_anonymous_link(permission: Dict[str, Any]) -> bool: """Check if a permission is an anonymous sharing link.""" link = permission.get("link") @@ -122,7 +118,7 @@ async def extract_access_control( if not has_read_permission(perm): continue - if is_org_wide_link(perm) or is_anonymous_link(perm): + if is_anonymous_link(perm): is_public = True continue diff --git a/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py new file mode 100644 index 000000000..0ff15551c --- /dev/null +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py @@ -0,0 +1,193 @@ +"""Unit tests for SharePoint Online ACL extraction. + +Covers ``extract_access_control`` and the rules around how Microsoft Graph +sharing-link permissions map to ``AccessControl.is_public``. + +Background — bug fix: + Organization-scoped sharing links (``link.scope == "organization"``, + "anyone in your org with the link") used to set ``is_public = True``, + which made the search broker bypass all viewer checks. That mis-modeled + SharePoint semantics: an org-scoped link requires possession of the + link URL to grant access. Only ``link.scope == "anonymous"`` is true + public access. +""" + +import pytest + +from airweave.platform.sources.sharepoint_online.acl import extract_access_control + +# --------------------------------------------------------------------------- +# Helpers — build minimal Graph permission objects +# --------------------------------------------------------------------------- + + +def _link_perm(scope: str, roles=None) -> dict: + """Sharing-link permission with the given scope (no grantedTo principal).""" + return { + "id": f"link-{scope}", + "roles": roles if roles is not None else ["write"], + "link": {"scope": scope, "type": "edit"}, + "grantedToIdentitiesV2": [], + "grantedToIdentities": [], + } + + +def _site_group_perm(name: str, group_id: str = "5", roles=None) -> dict: + return { + "id": f"sg-{group_id}", + "roles": roles if roles is not None else ["write"], + "grantedToV2": {"siteGroup": {"displayName": name, "id": group_id}}, + } + + +def _user_perm(email: str, roles=None) -> dict: + return { + "id": f"u-{email}", + "roles": roles if roles is not None else ["read"], + "grantedToV2": {"user": {"email": email, "displayName": email}}, + } + + +def _entra_group_perm(group_id: str, roles=None) -> dict: + return { + "id": f"eg-{group_id}", + "roles": roles if roles is not None else ["read"], + "grantedToV2": {"group": {"id": group_id}}, + } + + +# --------------------------------------------------------------------------- +# Sharing-link scope handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_organization_scoped_link_does_not_set_is_public(): + """Org-scoped link by itself must not flip is_public. + + Regression: the previous behavior treated organization-scoped links as + fully public, bypassing all viewer checks at search time. + """ + ac = await extract_access_control([_link_perm("organization")]) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_anonymous_link_sets_is_public(): + ac = await extract_access_control([_link_perm("anonymous")]) + assert ac.is_public is True + + +@pytest.mark.asyncio +async def test_org_and_anonymous_links_together_still_public_via_anonymous(): + ac = await extract_access_control([_link_perm("organization"), _link_perm("anonymous")]) + assert ac.is_public is True + + +@pytest.mark.asyncio +async def test_users_scoped_link_does_not_set_is_public(): + """``users``-scoped links target named recipients, not the org.""" + ac = await extract_access_control([_link_perm("users")]) + assert ac.is_public is False + + +@pytest.mark.asyncio +async def test_unknown_link_scope_does_not_set_is_public(): + """Future / unrecognized scopes default to non-public.""" + ac = await extract_access_control([_link_perm("someFutureScope")]) + assert ac.is_public is False + + +# --------------------------------------------------------------------------- +# Mixed permissions — the realistic case +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_org_link_alongside_explicit_grants_extracts_only_grants(): + """Org-link permission is skipped; explicit grants populate viewers. + + Mirrors the Mistral bug-report payload: a file with one organization- + scoped sharing link plus the inherited site-group grants. The fix must + keep is_public false while still extracting Owners / Members / Visitors. + """ + perms = [ + _link_perm("organization"), + _site_group_perm("Access Control Tests Owners", group_id="3", roles=["owner"]), + _site_group_perm("Access Control Tests Members", group_id="5", roles=["write"]), + _site_group_perm("Access Control Tests Visitors", group_id="4", roles=["read"]), + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert set(ac.viewers) == { + "group:sp:access_control_tests_owners", + "group:sp:access_control_tests_members", + "group:sp:access_control_tests_visitors", + } + + +@pytest.mark.asyncio +async def test_user_and_entra_group_grants_extracted(): + perms = [ + _user_perm("alice@example.com"), + _entra_group_perm("11111111-2222-3333-4444-555555555555"), + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert set(ac.viewers) == { + "user:alice@example.com", + "group:entra:11111111-2222-3333-4444-555555555555", + } + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_empty_permissions_returns_empty_access_control(): + ac = await extract_access_control([]) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_permission_without_read_role_is_ignored(): + """Roles other than read/write/owner/sp.full control don't grant viewing.""" + perms = [ + { + "id": "restricted", + "roles": ["restricted"], + "grantedToV2": {"user": {"email": "alice@example.com"}}, + }, + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_org_link_without_read_role_is_ignored_entirely(): + """A link without a read-equivalent role doesn't even reach scope check.""" + perms = [ + { + "id": "link-restricted", + "roles": ["restricted"], + "link": {"scope": "organization"}, + } + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_duplicate_principal_only_added_once(): + perms = [ + _user_perm("alice@example.com", roles=["read"]), + _user_perm("alice@example.com", roles=["write"]), + ] + ac = await extract_access_control(perms) + assert ac.viewers == ["user:alice@example.com"] From fccafde91bfea31e46c59e6411ea36aaef59b535 Mon Sep 17 00:00:00 2001 From: Daan Manneke Date: Tue, 5 May 2026 11:46:28 +0200 Subject: [PATCH 19/25] fix(sharepoint_online): scope sharing-link viewers per-file, drop blanket-attach MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SharePoint Online connector blanket-attached every site group returned by ``/_api/web/sitegroups`` to every file in the site (``_fetch_sp_group_viewers`` plus the merge block in ``_full_sync``). That over-granted in three ways: 1. ``SharingLinks..…`` system groups (one per sharing link in the site) attached to every file, so a user who redeemed a link for *one* file gained access to *every* file in the site. 2. Custom site groups (e.g. an admin's ``engineering-test`` group granted only on a specific subfolder) attached to every file, so members saw files outside the folder they were granted on. 3. Default ``Members`` / ``Owners`` / ``Visitors`` and any ``Limited Access System Group`` re-attached to files where the admin had broken inheritance and removed those very groups, undoing deliberate access tightening. Fix --- Trust the per-item Graph permissions response, which already returns the correct set of principals for both intact and broken inheritance. The one thing Graph does not represent as a ``siteGroup`` grant is the sharing-link audience: it returns a ``link`` permission instead, while SharePoint internally tracks redeemers as members of a ``SharingLinks...`` site group. We translate that ``link`` permission into the matching SharingLinks group viewer on that one file, scoped by the file's SharePoint UniqueId. Changes ------- - ``acl.py``: new ``link_permission_to_sp_group_viewer`` helper. Empirically verified scope/type → group-suffix mapping (organization+edit → OrganizationEdit, organization+view → OrganizationView, users+edit / users+view both → Flexible). ``extract_access_control`` grows an optional ``sp_unique_id`` parameter; when present, non-anonymous link permissions are appended as the per-link site-group viewer. - ``client.py``: new ``GraphClient.get_item_sp_unique_id`` — fetches just ``sharepointIds.listItemUniqueId`` for an item. - ``source.py``: deletes ``_fetch_sp_group_viewers`` and the merge block. Each file-build path now does a conditional UniqueId lookup only when the item's permissions actually contain a ``link`` block (most files have none, so the per-file cost is zero by default). - ``builders.py``: ``build_file_entity`` plumbs ``sp_unique_id`` through. Tests ----- 20 unit tests (existing + new) covering scope/type → group-suffix mapping, the anonymous → ``is_public`` short-circuit (no derived viewer), missing sp_unique_id / link_id / unknown scope returning ``None``, and the combined "link + explicit grants" payload. Verification ------------ End-to-end against neenacorp.sharepoint.com with five test fixtures designed to expose each over-grant mode independently. Pre/post viewer shrink confirmed (e.g. jean-doc.txt 8 viewers → 3, eng-doc.txt 7 → 3). Search-as-user matrix shows the SharingLinks redeemer (acl_test_user1) went from seeing all 4 fixture files to seeing only the file the link points to; the custom-group member (acl_test_user3) went from all 4 to just the file the group was granted on. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../platform/sources/sharepoint_online/acl.py | 77 +++++++++- .../sources/sharepoint_online/builders.py | 18 ++- .../sources/sharepoint_online/client.py | 23 +++ .../sources/sharepoint_online/source.py | 95 ++++++------- .../sources/test_sharepoint_online_acl.py | 133 +++++++++++++++--- 5 files changed, 268 insertions(+), 78 deletions(-) diff --git a/backend/airweave/platform/sources/sharepoint_online/acl.py b/backend/airweave/platform/sources/sharepoint_online/acl.py index 8a50b4b5e..b770540fe 100644 --- a/backend/airweave/platform/sources/sharepoint_online/acl.py +++ b/backend/airweave/platform/sources/sharepoint_online/acl.py @@ -7,10 +7,12 @@ - grantedToV2.group (Entra ID) → group:entra:{group_id} - grantedToV2.siteGroup → group:sp:{site_group_name} - link with scope "anonymous" → is_public (true tenant-wide / internet-wide access) - -Organization-scoped sharing links ("anyone in your org with the link") are NOT -treated as public. Possession of the link URL is required, so the audience is -captured via the per-link SharingLinks.* SP site group rather than is_public. +- link with scope "organization" or "users" → group:sp:sharinglinks.{itemId}.{scopeRole}.{linkId} + derived from the permission and the file's SP UniqueId. Microsoft represents + these links as a ``link`` permission rather than a ``siteGroup`` grant, but + internally tracks redeemers as members of a ``SharingLinks...`` + SP site group. We translate so that membership intersection at search time + works for users who have actually redeemed the link. """ from typing import Any, Dict, List, Optional @@ -88,6 +90,60 @@ def is_anonymous_link(permission: Dict[str, Any]) -> bool: return link.get("scope", "") == "anonymous" +# Mapping from Graph (link.scope, link.type) to SharePoint's SharingLinks group +# suffix. Verified empirically against neenacorp.sharepoint.com: +# organization+edit → OrganizationEdit +# organization+view → OrganizationView +# users+edit / users+view → Flexible (both collapse; SP stores role separately) +# Anonymous is handled by ``is_public`` and does not need a derived group. +_SCOPE_ROLE_MAP: Dict[tuple, str] = { + ("organization", "edit"): "OrganizationEdit", + ("organization", "view"): "OrganizationView", + ("users", "edit"): "Flexible", + ("users", "view"): "Flexible", +} + + +def link_permission_to_sp_group_viewer( + permission: Dict[str, Any], sp_unique_id: Optional[str] +) -> Optional[str]: + """Derive the SharingLinks SP site group viewer for a non-anonymous link permission. + + SharePoint creates an internal site group named + ``SharingLinks...`` for each sharing + link, whose members are the users who have redeemed the link. The Graph + per-item permissions response represents the link itself but does *not* + return that site group as a separate ``siteGroup`` grant, so we translate. + + Args: + permission: A Graph permission with a ``link`` block. + sp_unique_id: The file's SharePoint UniqueId (lowercase GUID, no + braces). Pass ``None`` for site/drive-level permissions, where + sharing-link translation does not apply. + + Returns: + ``group:sp:sharinglinks...`` viewer string, or + ``None`` if the permission isn't a translatable link or required + fields are missing. + """ + if not sp_unique_id: + return None + link = permission.get("link") + if not link: + return None + scope = link.get("scope", "") + if scope == "anonymous": + return None # handled by is_public + scope_role = _SCOPE_ROLE_MAP.get((scope, link.get("type", ""))) + if not scope_role: + return None # unknown scope/type combination — be conservative + link_id = permission.get("id", "") + if not link_id: + return None + title = f"SharingLinks.{sp_unique_id}.{scope_role}.{link_id}" + return f"group:sp:{title.lower()}" + + def _extract_identity_principals(perm: Dict[str, Any], viewers: List[str]) -> None: """Extract user principals from grantedToIdentitiesV2/grantedToIdentities.""" for identities_key in ("grantedToIdentitiesV2", "grantedToIdentities"): @@ -102,11 +158,17 @@ def _extract_identity_principals(perm: Dict[str, Any], viewers: List[str]) -> No async def extract_access_control( permissions: List[Dict[str, Any]], + sp_unique_id: Optional[str] = None, ) -> AccessControl: """Build AccessControl from Graph API permissions. Args: permissions: List of permission objects from Graph API. + sp_unique_id: The SharePoint UniqueId of the item the permissions + belong to (lowercase GUID, no braces). Required to translate + non-anonymous sharing-link permissions into their corresponding + ``SharingLinks.*`` SP site group viewer. Pass ``None`` for + site/drive-level permission lists. Returns: AccessControl with viewers and is_public flag. @@ -122,6 +184,13 @@ async def extract_access_control( is_public = True continue + # Non-anonymous sharing links: translate to the per-link SP site group. + link_viewer = link_permission_to_sp_group_viewer(perm, sp_unique_id) + if link_viewer: + if link_viewer not in viewers: + viewers.append(link_viewer) + continue + principal = extract_principal_from_permission(perm) if principal and principal not in viewers: viewers.append(principal) diff --git a/backend/airweave/platform/sources/sharepoint_online/builders.py b/backend/airweave/platform/sources/sharepoint_online/builders.py index eea0e83d1..eb8412b6a 100644 --- a/backend/airweave/platform/sources/sharepoint_online/builders.py +++ b/backend/airweave/platform/sources/sharepoint_online/builders.py @@ -96,8 +96,22 @@ async def build_file_entity( site_id: str, breadcrumbs: List[Breadcrumb], permissions: Optional[List[Dict[str, Any]]] = None, + sp_unique_id: Optional[str] = None, ) -> SharePointOnlineFileEntity: - """Build a file entity from Graph API drive item data.""" + """Build a file entity from Graph API drive item data. + + Args: + item_data: Graph drive item dict. + drive_id: Drive ID containing the item. + site_id: Site ID the drive belongs to. + breadcrumbs: Hierarchy breadcrumbs. + permissions: Optional permissions list from + ``/drives/{id}/items/{id}/permissions``. + sp_unique_id: Optional SharePoint ``listItemUniqueId`` for the item. + Required to translate sharing-link permissions; the caller should + fetch it via :meth:`GraphClient.get_item_sp_unique_id` when any + of the item's permissions has a ``link`` block. + """ item_id = item_data.get("id") if not item_id: raise EntityProcessingError("Missing id for file item") @@ -132,7 +146,7 @@ async def build_file_entity( download_url = item_data.get("@microsoft.graph.downloadUrl", "") spo_entity_id = f"spo:file:{drive_id}:{item_id}" - access = await extract_access_control(permissions or []) if permissions else None + access = await extract_access_control(permissions or [], sp_unique_id) if permissions else None return SharePointOnlineFileEntity( url=download_url or item_data.get("webUrl", ""), diff --git a/backend/airweave/platform/sources/sharepoint_online/client.py b/backend/airweave/platform/sources/sharepoint_online/client.py index f98d70d16..e4683a2db 100644 --- a/backend/airweave/platform/sources/sharepoint_online/client.py +++ b/backend/airweave/platform/sources/sharepoint_online/client.py @@ -303,6 +303,29 @@ async def get_item_permissions( return [] raise + async def get_item_sp_unique_id( + self, + drive_id: str, + item_id: str, + ) -> Optional[str]: + """Fetch the SharePoint ``listItemUniqueId`` (lowercase GUID) for a drive item. + + Used to translate sharing-link permissions into the underlying + ``SharingLinks...`` SP site group viewer. + Only worth calling when the item has at least one ``link`` permission; + for items with only direct grants there's nothing to translate. + """ + url = f"{GRAPH_BASE_URL}/drives/{drive_id}/items/{item_id}?$select=sharepointIds" + try: + data = await self.get(url) + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + return None + raise + sp_ids = data.get("sharepointIds") or {} + luid = sp_ids.get("listItemUniqueId") + return luid.lower() if luid else None + async def get_drive_root_permissions( self, drive_id: str, diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index c4cb4b999..5d41d26f5 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -598,48 +598,10 @@ async def _resolve_unresolved_viewers( new_viewers.append(v) entity.access.viewers = new_viewers - async def _fetch_sp_group_viewers(self, site_url: str) -> List[str]: - """Fetch all SP site groups for a site and return their viewer strings. - - Args: - site_url: Full site URL (e.g. https://tenant.sharepoint.com/sites/X). - Required — without it we can't hit the SP REST endpoint. - """ - norm_site = self._normalize_site_url(site_url) - if not norm_site: - return [] - sp_token_provider = self._make_sp_token_provider_for_site(norm_site) - if not sp_token_provider: - return [] - try: - token = await sp_token_provider() - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/json;odata=verbose", - } - resp = await self.http_client.get( - f"{norm_site}/_api/web/sitegroups", - headers=headers, - timeout=30.0, - ) - resp.raise_for_status() - groups = resp.json().get("d", {}).get("results", []) - - viewers = [] - site_bucket = self._item_level_sp_groups.setdefault(norm_site, set()) - for g in groups: - title = g.get("Title", "") - if title: - tag = f"group:sp:{title.lower().replace(' ', '_')}" - viewers.append(tag) - site_bucket.add(tag[len("group:") :]) - self.logger.info(f"Fetched {len(viewers)} SP site groups as viewers for {norm_site}") - return viewers - except SourceAuthError: - raise - except Exception as e: - self.logger.warning(f"SP group fetch failed for {norm_site}: {e}") - return [] + @staticmethod + def _has_link_permission(permissions: List[Dict[str, Any]]) -> bool: + """Return True if any permission carries a sharing-link block.""" + return any(p.get("link") for p in (permissions or [])) async def _full_sync( # noqa: C901 self, @@ -687,8 +649,6 @@ async def _full_sync( # noqa: C901 self.logger.warning(f"Skipping site {site_id}: {e}") continue - sp_group_viewers = await self._fetch_sp_group_viewers(site_url) - for drive_data in all_drives: drive_id = drive_data.get("id", "") try: @@ -730,20 +690,25 @@ async def _full_sync( # noqa: C901 item_data["id"], ) + # Sharing-link permissions need the file's SP UniqueId + # to translate into the SharingLinks.* SP site group. + # Skip the extra fetch when the file has no sharing links. + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_data["id"] + ) + file_entity = await build_file_entity( item_data, drive_id, site_id, drive_breadcrumbs, permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) - if sp_group_viewers and file_entity.access: - existing = set(file_entity.access.viewers or []) - for spv in sp_group_viewers: - if spv not in existing: - file_entity.access.viewers.append(spv) self._track_entity_groups(file_entity, site_url) if files: @@ -893,12 +858,18 @@ async def _incremental_sync( # noqa: C901 if item_data.get("file"): try: permissions = await graph_client.get_item_permissions(drive_id, item_id) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_id + ) file_entity = await build_file_entity( item_data, drive_id, "", [], permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) self._track_entity_groups(file_entity) @@ -1036,8 +1007,18 @@ async def _targeted_sync( # noqa: C901 item_data = await graph_client.get(url) if item_data.get("file"): permissions = await graph_client.get_item_permissions(drive_id, item_id) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_id + ) file_entity = await build_file_entity( - item_data, drive_id, "", [], permissions + item_data, + drive_id, + "", + [], + permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) self._track_entity_groups(file_entity) @@ -1117,7 +1098,7 @@ async def _sync_folder_recursive( ): yield entity - async def _process_file_items( + async def _process_file_items( # noqa: C901 self, graph_client: GraphClient, item_stream: AsyncGenerator[Dict[str, Any], None], @@ -1136,8 +1117,18 @@ async def _process_file_items( continue try: permissions = await graph_client.get_item_permissions(drive_id, item_data["id"]) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_data["id"] + ) file_entity = await build_file_entity( - item_data, drive_id, site_id, breadcrumbs, permissions + item_data, + drive_id, + site_id, + breadcrumbs, + permissions, + sp_unique_id=sp_unique_id, ) if resolve_viewers: await self._resolve_unresolved_viewers(file_entity, graph_client) diff --git a/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py index 0ff15551c..a976f9802 100644 --- a/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py @@ -1,32 +1,43 @@ """Unit tests for SharePoint Online ACL extraction. Covers ``extract_access_control`` and the rules around how Microsoft Graph -sharing-link permissions map to ``AccessControl.is_public``. - -Background — bug fix: - Organization-scoped sharing links (``link.scope == "organization"``, - "anyone in your org with the link") used to set ``is_public = True``, - which made the search broker bypass all viewer checks. That mis-modeled - SharePoint semantics: an org-scoped link requires possession of the - link URL to grant access. Only ``link.scope == "anonymous"`` is true - public access. +sharing-link permissions map to ``AccessControl``. + +Background — two related bug fixes: + +1. Organization-scoped sharing links (``link.scope == "organization"``, + "anyone in your org with the link") used to set ``is_public = True``, + which made the search broker bypass all viewer checks. Only + ``link.scope == "anonymous"`` is genuine public access. + +2. The previous code attempted to recover sharing-link audience by + blanket-attaching every site group (including SharingLinks system + groups for unrelated files) to every file in the site. That over-granted + massively. The fix is to translate each link permission into the + specific ``SharingLinks...`` SP site group + for that one file, scoped by the file's SharePoint UniqueId. """ import pytest -from airweave.platform.sources.sharepoint_online.acl import extract_access_control +from airweave.platform.sources.sharepoint_online.acl import ( + extract_access_control, + link_permission_to_sp_group_viewer, +) # --------------------------------------------------------------------------- # Helpers — build minimal Graph permission objects # --------------------------------------------------------------------------- -def _link_perm(scope: str, roles=None) -> dict: +def _link_perm( + scope: str, type_: str = "edit", roles=None, link_id: str = "link-1" +) -> dict: """Sharing-link permission with the given scope (no grantedTo principal).""" return { - "id": f"link-{scope}", + "id": link_id, "roles": roles if roles is not None else ["write"], - "link": {"scope": scope, "type": "edit"}, + "link": {"scope": scope, "type": type_}, "grantedToIdentitiesV2": [], "grantedToIdentities": [], } @@ -68,11 +79,54 @@ async def test_organization_scoped_link_does_not_set_is_public(): Regression: the previous behavior treated organization-scoped links as fully public, bypassing all viewer checks at search time. """ + # Without sp_unique_id the link cannot be translated into a SharingLinks + # site group viewer either — both halves of the fix combine to give + # "no public access, no viewer either". ac = await extract_access_control([_link_perm("organization")]) assert ac.is_public is False assert ac.viewers == [] +@pytest.mark.asyncio +async def test_organization_edit_link_with_sp_unique_id_yields_per_link_viewer(): + """When the file's SP UniqueId is known, an org+edit link translates.""" + perm = _link_perm("organization", type_="edit", link_id="LINK0001") + ac = await extract_access_control( + [perm], sp_unique_id="dd7691b0-3468-446f-81b0-72f3bdab7d1f" + ) + assert ac.is_public is False + assert ac.viewers == [ + "group:sp:sharinglinks.dd7691b0-3468-446f-81b0-72f3bdab7d1f.organizationedit.link0001" + ] + + +@pytest.mark.asyncio +async def test_organization_view_link_translates_to_organizationview_suffix(): + perm = _link_perm("organization", type_="view", link_id="LINK0002") + ac = await extract_access_control([perm], sp_unique_id="aaaa-bbbb") + assert ac.viewers == ["group:sp:sharinglinks.aaaa-bbbb.organizationview.link0002"] + + +@pytest.mark.asyncio +async def test_users_scope_link_translates_to_flexible_suffix(): + """Empirically verified: both users+edit and users+view collapse to Flexible.""" + perm_edit = _link_perm("users", type_="edit", link_id="LINKE") + perm_view = _link_perm("users", type_="view", link_id="LINKV") + ac_e = await extract_access_control([perm_edit], sp_unique_id="ITEM1") + ac_v = await extract_access_control([perm_view], sp_unique_id="ITEM1") + assert ac_e.viewers == ["group:sp:sharinglinks.item1.flexible.linke"] + assert ac_v.viewers == ["group:sp:sharinglinks.item1.flexible.linkv"] + + +@pytest.mark.asyncio +async def test_anonymous_link_does_not_get_translated_to_viewer(): + """Anonymous → is_public, never a SharingLinks viewer.""" + perm = _link_perm("anonymous", type_="view", link_id="LINKA") + ac = await extract_access_control([perm], sp_unique_id="ITEM1") + assert ac.is_public is True + assert ac.viewers == [] + + @pytest.mark.asyncio async def test_anonymous_link_sets_is_public(): ac = await extract_access_control([_link_perm("anonymous")]) @@ -105,25 +159,27 @@ async def test_unknown_link_scope_does_not_set_is_public(): @pytest.mark.asyncio -async def test_org_link_alongside_explicit_grants_extracts_only_grants(): - """Org-link permission is skipped; explicit grants populate viewers. +async def test_org_link_alongside_explicit_grants_extracts_grants_and_link_group(): + """Org-link plus explicit grants → both end up in viewers. - Mirrors the Mistral bug-report payload: a file with one organization- - scoped sharing link plus the inherited site-group grants. The fix must - keep is_public false while still extracting Owners / Members / Visitors. + Mirrors the Mistral bug-report payload shape: a file with one + organization-scoped sharing link plus inherited site-group grants. + Post-fix, is_public is False, all explicit grants are present, and + the per-link SharingLinks site group is included exactly once. """ perms = [ - _link_perm("organization"), + _link_perm("organization", link_id="LINK0001"), _site_group_perm("Access Control Tests Owners", group_id="3", roles=["owner"]), _site_group_perm("Access Control Tests Members", group_id="5", roles=["write"]), _site_group_perm("Access Control Tests Visitors", group_id="4", roles=["read"]), ] - ac = await extract_access_control(perms) + ac = await extract_access_control(perms, sp_unique_id="ITEM1") assert ac.is_public is False assert set(ac.viewers) == { "group:sp:access_control_tests_owners", "group:sp:access_control_tests_members", "group:sp:access_control_tests_visitors", + "group:sp:sharinglinks.item1.organizationedit.link0001", } @@ -191,3 +247,40 @@ async def test_duplicate_principal_only_added_once(): ] ac = await extract_access_control(perms) assert ac.viewers == ["user:alice@example.com"] + + +# --------------------------------------------------------------------------- +# link_permission_to_sp_group_viewer — None-return paths +# --------------------------------------------------------------------------- + + +def test_link_translation_returns_none_without_sp_unique_id(): + """No SP UniqueId means we can't construct the group name — return None.""" + perm = _link_perm("organization", link_id="L1") + assert link_permission_to_sp_group_viewer(perm, None) is None + + +def test_link_translation_returns_none_for_non_link_perm(): + """A non-link permission is not a sharing link — return None.""" + perm = {"id": "x", "roles": ["read"], "grantedToV2": {"user": {"email": "a@b.com"}}} + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_for_anonymous(): + perm = _link_perm("anonymous", link_id="L1") + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_for_unknown_scope(): + """Unknown / future scope: be conservative, don't fabricate a viewer.""" + perm = { + "id": "L1", + "roles": ["read"], + "link": {"scope": "future-scope", "type": "edit"}, + } + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_when_link_id_missing(): + perm = {"id": "", "roles": ["read"], "link": {"scope": "organization", "type": "edit"}} + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None From df2715d264f04d572dabcf9d4f641b6af586fef7 Mon Sep 17 00:00:00 2001 From: Daan Manneke Date: Tue, 5 May 2026 13:07:44 +0200 Subject: [PATCH 20/25] fix(sharepoint_online): expand "Everyone except external users" claim into per-user memberships MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The "Everyone except external users" claim (``c:0-.f|rolemanager|spo-grid-all-users/``) is the default audience SharePoint adds to most sites' Members and Visitors groups — i.e. the standard permission model on classic team sites and communication sites. SharePoint returns it as a member of site groups with ``PrincipalType=4`` (same as Entra federated groups), but with the ``rolemanager`` claim provider rather than ``federateddirectoryclaimprovider``. Our existing PT=4 branch in ``_parse_sp_group_member`` only matched the Entra shape, so the claim hit ``return None`` and was silently dropped. Net effect: any file accessible only via a site group that contained this claim was invisible in search to internal users without an alternate access path. No log line, no warning — just a silent drop. Fix --- Translate the claim into a synthetic group ``claim:everyone_except_external``, populated once per sync with the tenant's internal members (``userType eq 'Member'`` filter excludes B2B guests, preserving the claim's semantics). The broker's existing recursive group expansion then resolves user → claim:everyone_except_external → sp: → file at search time with no broker, schema, or search-path changes. Also: log a single info-level line whenever we encounter a PT=4 LoginName we don't recognize, so future unknown claim shapes (custom rolemanager roles, legacy Windows claims, etc.) surface in operator logs without breaking sync. Changes ------- - ``source.py``: - ``EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL`` synthetic group constant. - ``_EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE`` regex. - ``_parse_sp_group_member`` recognizes the claim and returns the sentinel; docstring corrected (PT=4 with rolemanager, not PT=16). - ``_is_unrecognized_pt4_login`` classifier for diagnostic logging. - ``_expand_sp_site_groups`` flips ``_needs_internal_user_enum`` when the sentinel is yielded. - New ``_expand_everyone_except_external`` enumerates internal users via Graph and yields user → claim memberships once per sync. - ``generate_access_control_memberships`` runs the enumeration only when the flag was set — zero cost on tenants that don't use the claim. - ``client.py``: new ``GraphClient.list_internal_tenant_users`` paginates ``/users?$filter=userType eq 'Member'``. Tests ----- 6 new unit tests in ``test_sharepoint_online_group_expansion.py`` covering the claim regex (lower- and uppercase tenant GUIDs), the synthetic sentinel return, ``_is_unrecognized_pt4_login`` for both known and unknown PT=4 shapes, and the non-PT=4 case. Existing 488 source-unit tests still pass. Verification ------------ End-to-end against neenacorp.sharepoint.com on a communication site (M365-group-backed sites refuse the claim by org policy, so the bug surface needs a non-group site). Set up: a custom site group ``claim-test-group`` containing only the claim; a folder with broken inheritance granted only to that group; a file ``claim-doc.txt`` inside. Pre-fix: 0 membership rows for ``claim-test-group`` in the DB despite the claim being present in SharePoint. Three internal users (``lennert@``, ``acl_test_user1@``, ``acl_test_user5@``) returned 0 hits on a search for the file's canary token. Post-fix: 49 user → claim rows + 1 claim → SP-group row in the DB. Same three users now return 1 hit each. PR #2's regression matrix on the access-control-tests collection remains unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../sources/sharepoint_online/client.py | 24 ++++ .../sources/sharepoint_online/source.py | 117 ++++++++++++++++-- .../test_sharepoint_online_group_expansion.py | 68 +++++++++- 3 files changed, 199 insertions(+), 10 deletions(-) diff --git a/backend/airweave/platform/sources/sharepoint_online/client.py b/backend/airweave/platform/sources/sharepoint_online/client.py index e4683a2db..fcfd945ff 100644 --- a/backend/airweave/platform/sources/sharepoint_online/client.py +++ b/backend/airweave/platform/sources/sharepoint_online/client.py @@ -303,6 +303,30 @@ async def get_item_permissions( return [] raise + async def list_internal_tenant_users(self) -> AsyncGenerator[Dict[str, str], None]: + """Yield internal tenant members (``userType eq 'Member'``). + + Used to expand the SharePoint "Everyone except external users" claim + into per-user memberships of the synthetic claim group. The Graph + filter excludes guests (``userType eq 'Guest'``), preserving the + claim's "except external" semantics. + + Yields: + Dicts with at least ``email`` (mail or userPrincipalName, + lowercased) and ``display_name``. Users without any addressable + identifier are skipped. + """ + url = ( + f"{GRAPH_BASE_URL}/users" + "?$filter=userType eq 'Member'" + "&$select=id,mail,userPrincipalName,displayName" + ) + async for u in self.get_paginated(url): + email = (u.get("mail") or u.get("userPrincipalName") or "").strip().lower() + if not email: + continue + yield {"email": email, "display_name": u.get("displayName") or email} + async def get_item_sp_unique_id( self, drive_id: str, diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index 5d41d26f5..5ba50ca54 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -75,6 +75,14 @@ MAX_CONCURRENT_FILE_DOWNLOADS = 10 ITEM_BATCH_SIZE = 50 +# Synthetic principal representing the SharePoint "Everyone except external users" +# claim. SP exposes this claim as a member of site groups but our membership +# table only handles real users / Entra groups / SP groups. We translate the +# claim into a synthetic group, populate it with the tenant's internal members +# at sync time, and let the broker's recursive expansion do the rest. +EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL = "claim:everyone_except_external" +EVERYONE_EXCEPT_EXTERNAL_DISPLAY_NAME = "Everyone except external users (synthetic)" + @dataclass class PendingFileDownload: @@ -110,6 +118,10 @@ class SharePointOnlineBase(BaseSource): # Site-scoped SP group tracking: {site_url: {sp_group_name, ...}} # Keyed by normalized site URL so multi-site syncs can expand SP groups per site. _item_level_sp_groups: Dict[str, Set[str]] + # Set to True during membership extraction when an SP group contains the + # "Everyone except external users" claim, so we know to enumerate internal + # tenant users once at the end. + _needs_internal_user_enum: bool def _init_common(self, config: SharePointOnlineConfig) -> None: """Initialize fields shared by both OAuth and client-credentials sources.""" @@ -118,6 +130,7 @@ def _init_common(self, config: SharePointOnlineConfig) -> None: self._include_pages = config.include_pages self._item_level_entra_groups = set() self._item_level_sp_groups = {} + self._needs_internal_user_enum = False # -- Auth hooks (subclasses override) -- @@ -234,6 +247,14 @@ def _track_entity_groups(self, entity: BaseEntity, site_url: str = "") -> None: r"(?P[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(_o)?$" ) + # Match the "Everyone except external users" claim: + # "c:0-.f|rolemanager|spo-grid-all-users/" + # PrincipalType=4 (SecurityGroup), but the claim provider is `rolemanager` + # rather than `federateddirectoryclaimprovider`. Represents all tenant + # users with userType=Member; excludes B2B guests by definition. + _EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE = re.compile( + r"^c:0-\.f\|rolemanager\|spo-grid-all-users/[0-9a-fA-F-]+$" + ) @classmethod def _email_from_membership_login(cls, login: str) -> Optional[str]: @@ -255,16 +276,26 @@ def _parse_sp_group_member(cls, user: Dict[str, Any]) -> Optional[Tuple[str, str """Parse one entry from /_api/web/sitegroups({id})/users into (member_id, member_type). Returns None for entries that should not become memberships: - - Role principals (PrincipalType=16, e.g. "Everyone except external users") - - Catch-all "All" principals (PrincipalType=15) - - DistList, SPGroup, unknown types (skipped; rare in practice) + - Catch-all "All" / "Everyone" principals (PrincipalType=15) + - DistList, SPGroup, RoleManager (other than the recognized claim below) - Unparseable entries (no email for users, no GUID for groups) + Recognized PrincipalType=4 shapes: + - Entra federated group: ``c:0o.c|federateddirectoryclaimprovider|[_o]`` + → returns ``("entra:", "group")``. + - "Everyone except external users" claim: + ``c:0-.f|rolemanager|spo-grid-all-users/`` → returns the + synthetic ``(EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, "group")`` sentinel. + The caller (``_expand_sp_site_groups``) then enumerates internal + tenant users once per sync to populate the synthetic group. + - Any other PT=4 LoginName: returns None. The caller logs the raw + shape at info-level so unknown claim shapes show up in operator + logs and can be wired up explicitly later. + PrincipalType reference: 1 = User 2 = DistList - 4 = SecurityGroup (Entra group when LoginName uses - federateddirectoryclaimprovider) + 4 = SecurityGroup (Entra group OR rolemanager claim) 8 = SPGroup 15 = All 16 = RoleManager @@ -285,15 +316,32 @@ def _parse_sp_group_member(cls, user: Dict[str, Any]) -> Optional[Tuple[str, str if ptype == 4: m = cls._ENTRA_GROUP_LOGIN_RE.match(login) - if not m: - return None - guid = m.group("guid").lower() - return (f"entra:{guid}", "group") + if m: + return (f"entra:{m.group('guid').lower()}", "group") + if cls._EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE.match(login): + return (EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, "group") + return None # PrincipalType 2 (DistList), 8 (SPGroup), 15 (All), 16 (RoleManager), # and unknown types are intentionally skipped. return None + @classmethod + def _is_unrecognized_pt4_login(cls, user: Dict[str, Any]) -> bool: + """Return True for a PT=4 entry whose LoginName matches none of our patterns. + + Used at the call site to emit a one-line diagnostic so that unknown + claim shapes (rare custom rolemanager roles, legacy Windows claims, + etc.) surface in operator logs without breaking sync. + """ + if user.get("PrincipalType") != 4: + return False + login = user.get("LoginName", "") or "" + return not ( + cls._ENTRA_GROUP_LOGIN_RE.match(login) + or cls._EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE.match(login) + ) + # -- Browse Tree -- BROWSE_TREE_MAX_ITEMS = 500 @@ -1248,6 +1296,12 @@ async def _expand_sp_site_groups( # noqa: C901 for user in users: parsed = self._parse_sp_group_member(user) if parsed is None: + if self._is_unrecognized_pt4_login(user): + self.logger.info( + "Unrecognized PrincipalType=4 SP group member; skipped. " + f"LoginName={user.get('LoginName', '')!r} " + f"Title={user.get('Title', '')!r}" + ) continue member_id, member_type = parsed yield MembershipTuple( @@ -1256,6 +1310,42 @@ async def _expand_sp_site_groups( # noqa: C901 group_id=sp_name, group_name=user.get("Title") or sp_name, ) + if member_id == EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL: + self._needs_internal_user_enum = True + + async def _expand_everyone_except_external( + self, + ) -> AsyncGenerator[MembershipTuple, None]: + """Populate the synthetic ``EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL`` group. + + Called once per sync, only when the SP group expansion observed at + least one occurrence of the claim. Enumerates internal tenant users + via Graph (``userType eq 'Member'`` filter excludes B2B guests) and + yields one user → claim membership per user. The broker's recursive + group expansion then chains user → claim → SP group at search time. + """ + graph_client = self._create_graph_client() + count = 0 + try: + async for u in graph_client.list_internal_tenant_users(): + count += 1 + yield MembershipTuple( + member_id=u["email"], + member_type="user", + group_id=EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + group_name=EVERYONE_EXCEPT_EXTERNAL_DISPLAY_NAME, + ) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning( + f"Failed to enumerate internal tenant users for " + f"'{EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL}': {e}" + ) + self.logger.info( + f"Populated synthetic '{EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL}' group " + f"with {count} internal tenant users" + ) async def generate_access_control_memberships( self, @@ -1263,6 +1353,7 @@ async def generate_access_control_memberships( """Expand Entra ID groups and SP site groups into user memberships.""" self.logger.info("Starting access control membership extraction") membership_count = 0 + self._needs_internal_user_enum = False group_expander = self._create_group_expander() async for m in self._expand_entra_groups(group_expander): @@ -1278,6 +1369,14 @@ async def generate_access_control_memberships( except Exception as e: self.logger.warning(f"SP site group expansion failed: {e}") + # If any SP site group contained the "Everyone except external users" + # claim, populate the synthetic claim group with internal tenant users + # exactly once. Skipped entirely when no group used the claim. + if self._needs_internal_user_enum: + async for m in self._expand_everyone_except_external(): + yield m + membership_count += 1 + group_expander.log_stats() self.logger.info(f"Access control extraction complete: {membership_count} memberships") diff --git a/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py index 669912b17..45c2f8c54 100644 --- a/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py @@ -6,7 +6,10 @@ from unittest.mock import MagicMock -from airweave.platform.sources.sharepoint_online.source import SharePointOnlineBase +from airweave.platform.sources.sharepoint_online.source import ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + SharePointOnlineBase, +) # --------------------------------------------------------------------------- # _email_from_membership_login @@ -181,6 +184,69 @@ def test_parse_security_group_non_federated_returns_none(): assert SharePointOnlineBase._parse_sp_group_member(user) is None +def test_parse_everyone_except_external_claim_returns_synthetic_sentinel(): + """The rolemanager/spo-grid-all-users claim → synthetic group sentinel.""" + user = { + "PrincipalType": 4, + "LoginName": ( + "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45" + ), + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + "group", + ) + + +def test_parse_everyone_except_external_uppercase_tenant_id(): + """Tenant ID GUIDs may be upper- or lowercase; both should match.""" + user = { + "PrincipalType": 4, + "LoginName": ( + "c:0-.f|rolemanager|spo-grid-all-users/26ADF163-2699-4D04-A0AD-3D935411BF45" + ), + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + "group", + ) + + +def test_parse_other_rolemanager_claim_skipped_and_flagged_as_unrecognized(): + """A different rolemanager claim shouldn't match; should be flagged for logging.""" + user = { + "PrincipalType": 4, + "LoginName": "c:0-.f|rolemanager|some-future-claim", + "Title": "Custom Role", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + assert SharePointOnlineBase._is_unrecognized_pt4_login(user) is True + + +def test_is_unrecognized_pt4_login_false_for_known_shapes(): + """Known PT=4 shapes (Entra group, claim) must NOT be flagged as unrecognized.""" + entra = { + "PrincipalType": 4, + "LoginName": ( + "c:0o.c|federateddirectoryclaimprovider|7d344400-39bc-4ee7-aa6e-437bd8de85c0" + ), + } + claim = { + "PrincipalType": 4, + "LoginName": "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45", + } + assert SharePointOnlineBase._is_unrecognized_pt4_login(entra) is False + assert SharePointOnlineBase._is_unrecognized_pt4_login(claim) is False + + +def test_is_unrecognized_pt4_login_false_for_non_pt4(): + """The flag is scoped to PT=4 only; other PrincipalTypes are skipped silently.""" + user = {"PrincipalType": 1, "LoginName": "i:0#.f|membership|alice@example.com"} + assert SharePointOnlineBase._is_unrecognized_pt4_login(user) is False + + def test_parse_distlist_skipped(): user = {"PrincipalType": 2, "LoginName": "some-dl", "Title": "DL"} assert SharePointOnlineBase._parse_sp_group_member(user) is None From 988035e8ef4994923b6deb63aa4d0fb45dccc08f Mon Sep 17 00:00:00 2001 From: Beth Horn <170328193+bethh0rn@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:37:24 +0200 Subject: [PATCH 21/25] Edit: Welcome --- fern/docs/pages/welcome.mdx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fern/docs/pages/welcome.mdx b/fern/docs/pages/welcome.mdx index e9e04fc09..c0c764250 100644 --- a/fern/docs/pages/welcome.mdx +++ b/fern/docs/pages/welcome.mdx @@ -42,7 +42,7 @@ Airweave continuously syncs information from connected sources and makes it avai ## Who it's for -Developers and teams building AI agents and other AI-powered applications that need reliable access to information across multiple tools and data sources. +Ideal for developers and teams building AI agents and AI-powered applications, Airweave provides reliable access to information across multiple tools and data sources. If you're working on long-running AI agents, retrieval-augmented generation, or any context-heavy LLM application, Airweave provides the infrastructure to retrieve and manage context without maintaining bespoke integrations for every data source. @@ -69,7 +69,7 @@ Common use cases include: title="Multi-Source Context Retrieval" icon="fa-solid fa-database" > - Retrieve and combine relevant context from structured and unstructured sources at query time. + Retrieve and combine relevant context from structured and unstructured sources at query time @@ -97,6 +97,6 @@ In all cases, Airweave helps agents retrieve facts from the right source instead icon="fa-solid fa-home" href="https://airweave.ai" > - High-level overview and latest product updates + See a high-level overview and the latest product updates From d3be983c28927a89fe922302097efc21dc932651 Mon Sep 17 00:00:00 2001 From: Beth Horn <170328193+bethh0rn@users.noreply.github.com> Date: Wed, 22 Apr 2026 10:53:49 +0200 Subject: [PATCH 22/25] Edit: Quickstart --- fern/docs/pages/quickstart.mdx | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/fern/docs/pages/quickstart.mdx b/fern/docs/pages/quickstart.mdx index 565a53f48..2ca124b45 100644 --- a/fern/docs/pages/quickstart.mdx +++ b/fern/docs/pages/quickstart.mdx @@ -12,7 +12,7 @@ Follow this guide to get up and running with Airweave in just a few steps. The simplest way to use Airweave is through our hosted cloud platform at [app.airweave.ai](https://app.airweave.ai). - If you prefer to run Airweave yourself, you can deploy it locally on macOS, Linux or WSL. After cloning the repository and starting the server, you will be able to open the dashboard at http://localhost:8080 + If you prefer to run Airweave yourself, you can deploy it locally on macOS, Linux or WSL. After cloning the repository and starting the server, open the dashboard at [http://localhost:8080](http://localhost:8080). ```bash git clone https://github.com/airweave-ai/airweave.git @@ -21,7 +21,7 @@ Follow this guide to get up and running with Airweave in just a few steps. ``` - + Airweave provides SDKs for Python and Node.js. Install the package. @@ -47,10 +47,10 @@ Follow this guide to get up and running with Airweave in just a few steps. > Your browser does not support the video tag. - Initialize the Airweave client with your new API key. For local deployments, set base_url to `"http://localhost:8001"`. + Initialize the Airweave client with your new API key. For local deployments, set `base_url` to `"http://localhost:8001"`. - ```Python title="Python" + ```python title="Python" from airweave import AirweaveSDK airweave = AirweaveSDK(api_key="YOUR_API_KEY", base_url="https://api.airweave.ai") @@ -69,7 +69,7 @@ Follow this guide to get up and running with Airweave in just a few steps. A collection is a group of different data sources that you can search using a single endpoint. - ```Python title="Python" + ```python title="Python" collection = airweave.collections.create(name="My First Collection") print(f"Created collection: {collection.readable_id}") @@ -92,11 +92,11 @@ Follow this guide to get up and running with Airweave in just a few steps. - - A source connection links a specific app or database to your collection. It handles authentication and automatically syncs data. + + Source connections link specific apps or databases to your collection. They handle authentication and automatically sync data. You need at least one source connection per collection. -```Python title="Python" +```python title="Python" source_connection = airweave.source_connections.create( name="My Stripe Connection", short_name="stripe", @@ -149,10 +149,10 @@ curl -X POST 'https://api.airweave.ai/source-connections' \ You can now search your collection and get the most relevant results from all connected sources. -```Python title="Python" +```python title="Python" response = airweave.collections.search.instant( readable_id=collection.readable_id, - query="Find returned payments from user John Doe?", + query="Find returned payments from user John Doe", ) for result in response.results: @@ -162,7 +162,7 @@ for result in response.results: ```javascript title="Node.js" const response = await airweave.collections.search.instant( collection.readableId, - { query: "Find returned payments from user John Doe?" } + { query: "Find returned payments from user John Doe" } ); response.results.forEach(result => console.log(result.name, result.relevanceScore)); @@ -172,7 +172,7 @@ response.results.forEach(result => console.log(result.name, result.relevanceScor curl -X POST 'https://api.airweave.ai/collections/my-first-collection-abc123/search/instant' \ -H 'x-api-key: YOUR_API_KEY' \ -H 'Content-Type: application/json' \ - -d '{"query": "Find returned payments from user John Doe?"}' + -d '{"query": "Find returned payments from user John Doe"}' ``` From 02e0e6398852dd4e8676ab7982f0c3a67fa458d4 Mon Sep 17 00:00:00 2001 From: Beth Horn <170328193+bethh0rn@users.noreply.github.com> Date: Wed, 22 Apr 2026 13:16:03 +0200 Subject: [PATCH 23/25] Proofread: Concepts --- fern/docs/pages/concepts.mdx | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/fern/docs/pages/concepts.mdx b/fern/docs/pages/concepts.mdx index a5387018f..9743e2141 100644 --- a/fern/docs/pages/concepts.mdx +++ b/fern/docs/pages/concepts.mdx @@ -6,10 +6,9 @@ slug: concepts Airweave connects to your apps, databases, and documents, then turns them into knowledge you can search. To understand how it works, you only need a few core concepts. - ## Source -A **Source** is an external application, database, or document store where your data lives. Sources are the systems Airweave pulls data from to build your retrieval layer. +A **Source** is an external application, database, or document store where your data lives. Sources are the systems from which Airweave pulls data to build your retrieval layer. Sources can be: - **Productivity tools**: Notion, Slack, Asana, Jira, Confluence @@ -23,7 +22,7 @@ Each source type has its own data structures, authentication methods, and API pa A **Connector** is the integration code that allows Airweave to communicate with a specific source. Each connector handles: -- **Authentication**: OAuth flows, API keys, or database credentials depending on the source +- **Authentication**: Using OAuth flows, API keys, or database credentials depending on the source - **Data extraction**: Fetching records, documents, or rows from the source API - **Entity mapping**: Transforming source-specific data structures into Airweave's unified entity format - **Incremental sync**: Tracking changes so only new or modified data is re-synced @@ -39,7 +38,7 @@ A **Source Connection** is a configured, authenticated instance of a connector l When you create a source connection, you: 1. Select a connector (e.g., Slack) 2. Authenticate with your credentials (e.g., OAuth login to your Slack workspace) -3. Assign it to a collection +3. Assign the connection to a collection Once created, Airweave continuously syncs data from that source connection, keeping your retrieval layer fresh and up-to-date. You can have multiple source connections of the same type (e.g., connecting to several different Slack workspaces). @@ -47,10 +46,10 @@ Once created, Airweave continuously syncs data from that source connection, keep An **Entity** is a single, searchable item extracted from a source. Entities are the atomic units of data that get indexed and returned in search results. -Examples of entities: +Examples of entities include: - A Slack message or thread - A Notion page or database row -- A GitHub issue, pull request, or code file +- A GitHub issue, pull request (PR), or code file - A Google Doc or spreadsheet - A Zendesk ticket or customer conversation - An Airtable record or row @@ -78,4 +77,4 @@ When an agent searches a collection, the query runs across all entities from all Collections can be queried via the REST API, SDKs, or MCP, making them accessible to any AI agent, RAG pipeline, or application that needs grounded, up-to-date context. -To learn more about querying collections, including the three search tiers (instant, classic, agentic), retrieval strategies, and filtering, see the [Search](/search) documentation. +To learn more about querying collections, including the three search tiers (instant, classic, and agentic), retrieval strategies, and filtering, see the [Search](/search) documentation. From 6dc42a315c1af0366bbefd3ee09c500266060aef Mon Sep 17 00:00:00 2001 From: Beth Horn <170328193+bethh0rn@users.noreply.github.com> Date: Wed, 22 Apr 2026 13:16:50 +0200 Subject: [PATCH 24/25] Define RAG abbreviation (used later in Concepts) --- fern/docs/pages/welcome.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fern/docs/pages/welcome.mdx b/fern/docs/pages/welcome.mdx index c0c764250..0fab0d7b5 100644 --- a/fern/docs/pages/welcome.mdx +++ b/fern/docs/pages/welcome.mdx @@ -44,7 +44,7 @@ Airweave continuously syncs information from connected sources and makes it avai Ideal for developers and teams building AI agents and AI-powered applications, Airweave provides reliable access to information across multiple tools and data sources. -If you're working on long-running AI agents, retrieval-augmented generation, or any context-heavy LLM application, Airweave provides the infrastructure to retrieve and manage context without maintaining bespoke integrations for every data source. +If you're working on long-running AI agents, retrieval-augmented generation (RAG), or any context-heavy LLM application, Airweave provides the infrastructure to retrieve and manage context without maintaining bespoke integrations for every data source. ## Use cases From 8fff722ebde8c981a80d6c950607ba35a836628d Mon Sep 17 00:00:00 2001 From: Beth Horn <170328193+bethh0rn@users.noreply.github.com> Date: Wed, 22 Apr 2026 13:42:28 +0200 Subject: [PATCH 25/25] Proofread: Search --- fern/docs/pages/search.mdx | 48 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/fern/docs/pages/search.mdx b/fern/docs/pages/search.mdx index d0d343e59..2d6615b20 100644 --- a/fern/docs/pages/search.mdx +++ b/fern/docs/pages/search.mdx @@ -5,46 +5,46 @@ edit-this-page-url: https://github.com/airweave-ai/airweave/blob/main/fern/docs/ slug: search --- -## Instant Search +## Instant search `POST /collections/{id}/search/instant` -Direct vector search. Use when speed is critical (~0.5sec). +Use this direct vector search when speed is critical (~0.5sec). -The only parameter unique to instant is `retrieval_strategy`, which controls how the vector database matches your query: +The only parameter unique to instant search is `retrieval_strategy`, which controls how the vector database matches your query: -- **`hybrid`** (default) — Combines semantic and keyword search via Reciprocal Rank Fusion. Best for most queries. -- **`semantic`** — Dense vector cosine similarity. Finds conceptually similar content even when wording differs. -- **`keyword`** — BM25 text matching. Only returns content with your exact terms. Use for error codes, identifiers, or known phrases. +- **`hybrid`** (default) combines semantic and keyword search via Reciprocal Rank Fusion. It's best for most queries. +- **`semantic`** uses dense vector cosine similarity. It finds conceptually similar content even when wording differs. +- **`keyword`** uses BM25 text matching and only returns content with your exact terms. Use it for error codes, identifiers, or known phrases. In classic and agentic search, the retrieval strategy is chosen automatically. -## Classic Search +## Classic search `POST /collections/{id}/search/classic` -AI-optimized search strategy. Sensible default for most use cases (~2sec). +Classic search uses an AI-optimized search strategy. It's a sensible default for most use cases (~2sec). An LLM analyzes your query and generates an optimized search strategy. -## Agentic Search +## Agentic search `POST /collections/{id}/search/agentic` -Agent that navigates through your collection to find the best results. Use when recall matters more than latency (<2min). +Agentic search uses an agent to navigate through your collection to find the best results. Use it when recall matters more than latency (<2min). An AI agent iteratively searches your data using tool calling. It searches with multiple strategies, reads full documents, navigates entity hierarchies (parent/child/sibling), and builds a comprehensive result set. -Two parameters unique to agentic: +Two parameters are unique to agentic search: -- **`thinking`** — Enables extended chain-of-thought reasoning before tool calls. Better search strategies, but slower and uses more tokens. Useful for complex or ambiguous queries. -- **`limit`** — Unlike instant/classic where the vector database always returns up to `limit` results, the agent collects results based on relevance. It may return fewer if it decides there aren't enough matches. Setting a limit caps the maximum — if the agent collects more, results are truncated. When `null` (default), there is no cap. +- **`thinking`** enables extended chain-of-thought reasoning before tool calls. It results in better search strategies, but is slower and uses more tokens. It's useful for complex or ambiguous queries. +- **`limit`** enables the agent to collect results based on relevance instead of having the vector data return up to `limit` results like in instant and classic search. The agent may return fewer results if it decides there aren't enough matches. Setting a limit caps the maximum — if the agent collects more, results are truncated. When the `limit` is `null` (default), there is no cap on the number of results. ### Streaming `POST /collections/{id}/search/agentic/stream` -Real-time SSE events as the agent works. Events are delivered as `data: {json}\n\n` messages. The stream terminates after a `done` or `error` event. +Streaming lets you see real-time SSE events as the agent works. Events are delivered as `data: {json}\n\n` messages. The stream terminates after a `done` or `error` event. Emitted once when the search begins. @@ -64,7 +64,7 @@ Emitted once when the search begins. -Emitted once per iteration after the LLM responds. `thinking` contains extended reasoning (when enabled), `text` contains conversational output before tool calls. +Emitted once per iteration after the LLM responds. `thinking` contains extended reasoning (when enabled), and `text` contains conversational output before tool calls. ```json { @@ -82,7 +82,7 @@ Emitted once per iteration after the LLM responds. `thinking` contains extended -Emitted after each tool the agent calls. `diagnostics.arguments` has the full tool input, `diagnostics.stats` has the output. The stats shape depends on which tool was called: +Emitted after each tool the agent calls. `diagnostics.arguments` has the full tool input, and `diagnostics.stats` has the output. The stats shape depends on which tool was called: ```json @@ -432,9 +432,9 @@ Emitted when the search fails. Also terminates the stream. ## Filters -Filters constrain search results by metadata. They work across all three tiers. +Filters constrain search results by metadata. They work across all three tiers (instant, classic, and agentic search). -In classic and agentic search, the AI generates its own filters internally, your filters are **AND'd into every search** it performs, acting as constraints that cannot be bypassed. +In classic and agentic search, the AI generates its own filters internally. Your filters are **AND'd into every search** it performs, acting as constraints that cannot be bypassed. ### Structure @@ -457,7 +457,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -### Filterable Fields +### Filterable fields | Field | Type | Description | |-------|------|-------------| @@ -534,7 +534,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -**Combine groups with OR — Slack messages OR Notion pages:** +**Combine groups with OR (Slack messages OR Notion pages):** ```json { @@ -555,7 +555,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -**Navigate hierarchy — find all entities inside a parent:** +**Navigate hierarchy (find all entities inside a parent):** ```json { @@ -569,7 +569,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -### Validation Rules +### Validation rules - Date fields (`created_at`, `updated_at`) require ISO 8601 timestamps (e.g., `2025-01-15T00:00:00Z`) - Ordering operators (`greater_than`, `less_than`, etc.) only work on date and numeric fields @@ -578,9 +578,9 @@ This allows expressions like: `(A AND B) OR (C AND D)` - Scalar operators (`equals`, `contains`, etc.) require a single value, not an array -## Response Format +## Response format -All three tiers return the same `SearchV2Response` with a `results` array. See the [API Reference](/api-reference/collections/instant-search) for the full response schema and interactive examples. +All three search tiers return the same `SearchV2Response` with a `results` array. See the [API Reference](/api-reference/collections/instant-search) for the full response schema and interactive examples. ## Configuring the LLM provider chain