diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b3c69cf17..ec7ff3cd6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -167,6 +167,7 @@ jobs: images: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} - name: Trivy image scan + id: trivy uses: aquasecurity/trivy-action@97e0b3872f55f89b95b2f65b3dbab56962816478 # v0.34.2 with: image-ref: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}:${{ steps.meta.outputs.version }} @@ -179,13 +180,15 @@ jobs: limit-severities-for-sarif: true - name: Upload Trivy SARIF - if: always() + if: ${{ !cancelled() && steps.trivy.outcome != 'skipped' }} uses: github/codeql-action/upload-sarif@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 with: sarif_file: trivy-${{ matrix.component }}.sarif category: trivy-${{ matrix.component }} - name: Generate SBOM + id: sbom + if: ${{ !cancelled() }} uses: anchore/sbom-action@17ae1740179002c89186b61233e0f892c3118b11 # v0.23.0 with: image: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}:${{ steps.meta.outputs.version }} @@ -193,13 +196,24 @@ jobs: output-file: sbom-${{ matrix.component }}.json - name: Vulnerability scan on SBOM + id: grype + if: ${{ !cancelled() && steps.sbom.outcome == 'success' }} uses: anchore/scan-action@7037fa011853d5a11690026fb85feee79f4c946c # v7.3.2 with: sbom: sbom-${{ matrix.component }}.json fail-build: "true" severity-cutoff: critical + only-fixed: "true" + + - name: Upload Grype SARIF + if: ${{ !cancelled() && steps.grype.outputs.sarif }} + uses: github/codeql-action/upload-sarif@c793b717bc78562f491db7b0e93a3a178b099162 # v4.32.5 + with: + sarif_file: ${{ steps.grype.outputs.sarif }} + category: grype-${{ matrix.component }} - name: Get image digest + if: ${{ !cancelled() }} id: digest run: | IMAGE="${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }}" @@ -209,6 +223,7 @@ jobs: echo "Image digest: ${DIGEST}" - name: Attest SBOM + if: ${{ 
!cancelled() && steps.digest.outcome == 'success' && steps.sbom.outcome == 'success' }} uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0 with: subject-name: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} @@ -217,6 +232,7 @@ jobs: push-to-registry: true - name: Attest build provenance + if: ${{ !cancelled() && steps.digest.outcome == 'success' }} uses: actions/attest@59d89421af93a897026c735860bf21b6eb4f7b26 # v4.1.0 with: subject-name: ${{ env.REGISTRY }}/${{ github.repository }}-${{ matrix.component }} @@ -224,6 +240,7 @@ jobs: push-to-registry: true - name: Upload SBOM artifact + if: ${{ !cancelled() && steps.sbom.outcome == 'success' }} uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: sbom-${{ matrix.component }} @@ -233,6 +250,7 @@ jobs: # Package Vespa application and attach to the GitHub Release package-vespa: needs: scan-and-attest + if: ${{ !cancelled() && needs.scan-and-attest.result != 'skipped' }} runs-on: ubuntu-latest permissions: contents: write diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 3d18d8f3d..fb3f2cb64 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -83,9 +83,17 @@ jobs: uses: snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1.4.1 with: version: 2.3.2 - virtualenvs-create: false + virtualenvs-create: true + virtualenvs-in-project: true + - name: Load cached venv + id: cached-mypy-deps + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ./backend/.venv + key: mypy-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} - name: Install dependencies - run: poetry install + if: steps.cached-mypy-deps.outputs.cache-hit != 'true' + run: poetry install --no-interaction - name: Run mypy if: github.event_name == 'push' run: poetry run mypy --config-file pyproject.toml airweave/ @@ -116,8 +124,16 @@ jobs: uses: 
snok/install-poetry@76e04a911780d5b312d89783f7b1cd627778900a # v1.4.1 with: version: 2.3.2 - virtualenvs-create: false + virtualenvs-create: true + virtualenvs-in-project: true + - name: Load cached venv + id: cached-importlint-deps + uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3 + with: + path: ./backend/.venv + key: lint-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }} - name: Install lint dependencies + if: steps.cached-importlint-deps.outputs.cache-hit != 'true' run: poetry install --only lint --no-interaction --no-root - name: Run import-linter run: poetry run lint-imports diff --git a/.github/workflows/fern-docs.yml b/.github/workflows/fern-docs.yml index 6a652c8b0..054c9bbe1 100644 --- a/.github/workflows/fern-docs.yml +++ b/.github/workflows/fern-docs.yml @@ -71,7 +71,7 @@ jobs: - name: Setup Node.js uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # v6.2.0 with: - node-version: "18" + node-version: "20" - name: Install Fern run: npm install -g fern-api diff --git a/backend/airweave/adapters/llm/__init__.py b/backend/airweave/adapters/llm/__init__.py index a00b5dfd2..9733d29d0 100644 --- a/backend/airweave/adapters/llm/__init__.py +++ b/backend/airweave/adapters/llm/__init__.py @@ -12,6 +12,7 @@ ) from airweave.adapters.llm.fallback import FallbackChainLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.override import create_llm_from_override from airweave.adapters.llm.registry import ( MODEL_REGISTRY, @@ -35,6 +36,7 @@ "AnthropicLLM", "CerebrasLLM", "GroqLLM", + "MistralLLM", "TogetherLLM", "FallbackChainLLM", # Exceptions diff --git a/backend/airweave/adapters/llm/mistral.py b/backend/airweave/adapters/llm/mistral.py new file mode 100644 index 000000000..68e6717a3 --- /dev/null +++ b/backend/airweave/adapters/llm/mistral.py @@ -0,0 +1,260 @@ +"""Mistral LLM implementation. 
+ +Uses the native mistralai SDK for chat completions with json_schema +structured output and OpenAI-compatible tool/function calling. + +Supports reasoning/thinking via two mechanisms: +- Magistral models: native thinking (always-on), returned as ThinkChunk content blocks +- Mistral Small 4: adjustable reasoning via reasoning_effort parameter +""" + +import json +import time +from typing import Any, TypeVar + +from mistralai import Mistral +from mistralai.models import AssistantMessageContent +from mistralai.models.jsonschema import JSONSchema +from mistralai.models.responseformat import ResponseFormat +from mistralai.models.textchunk import TextChunk +from mistralai.models.thinkchunk import ThinkChunk +from mistralai.types.basemodel import Unset +from pydantic import BaseModel + +from airweave.adapters.llm.base import BaseLLM +from airweave.adapters.llm.exceptions import LLMTransientError +from airweave.adapters.llm.registry import LLMModelSpec +from airweave.adapters.llm.tool_response import LLMResponse, LLMToolCall +from airweave.core.config import settings + +T = TypeVar("T", bound=BaseModel) + + +class MistralLLM(BaseLLM): + """Mistral LLM provider with json_schema structured output and tool calling.""" + + def __init__( + self, + model_spec: LLMModelSpec, + max_retries: int | None = None, + ) -> None: + """Initialize the Mistral LLM client with API key validation.""" + super().__init__(model_spec, max_retries=max_retries) + + api_key = settings.MISTRAL_API_KEY + if not api_key: + raise ValueError( + "MISTRAL_API_KEY not configured. Set it in your environment or .env file." 
+ ) + + try: + self._client = Mistral(api_key=api_key) + except Exception as e: + raise RuntimeError(f"Failed to initialize Mistral client: {e}") from e + + self._logger.debug( + f"[MistralLLM] Initialized with model={model_spec.api_model_name}, " + f"context_window={model_spec.context_window}, " + f"max_output_tokens={model_spec.max_output_tokens}" + ) + + def _prepare_schema(self, schema_json: dict[str, Any]) -> dict[str, Any]: + return self._normalize_strict_schema(schema_json) + + async def _call_api( + self, + prompt: str, + schema: type[T], + schema_json: dict[str, Any], + system_prompt: str, + thinking: bool = False, + ) -> T: + api_start = time.monotonic() + response = await self._client.chat.complete_async( + model=self._model_spec.api_model_name, + messages=[ # type: ignore[arg-type] + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + response_format=ResponseFormat( # type: ignore[arg-type] + type="json_schema", + json_schema=JSONSchema( + name=schema.__name__.lower(), + strict=True, + schema_definition=schema_json, + ), + ), + max_tokens=self._model_spec.max_output_tokens, + ) + api_time = time.monotonic() - api_start + + # Empty body (including ThinkChunk-only content for reasoning models) + # is transient: retry often clears a momentary truncation on the API side. 
+ content = _extract_text(response.choices[0].message.content) + if not content: + raise LLMTransientError( + "Mistral returned empty response content", + provider=self._name, + ) + + if response.usage: + self._logger.debug( + f"[MistralLLM] API call completed in {api_time:.2f}s, " + f"tokens: prompt={response.usage.prompt_tokens}, " + f"completion={response.usage.completion_tokens}, " + f"total={response.usage.total_tokens}" + ) + + return self._parse_json_response(content, schema) + + async def _call_api_chat( + self, + messages: list[dict], + tools: list[dict], + system_prompt: str, + thinking: bool = False, + max_tokens: int | None = None, + ) -> LLMResponse: + """Mistral tool calling with OpenAI-compatible format.""" + converted = self._prepare_messages_for_api(messages) + api_messages = [{"role": "system", "content": system_prompt}, *converted] + + # Mistral uses OpenAI-compatible tool definitions directly + strict_tools = self._prepare_tools_strict(tools) + + # Build reasoning params based on thinking config + reasoning_params: dict[str, Any] = {} + tc = self._model_spec.thinking_config + if tc and tc.param_name == "reasoning_effort": + reasoning_params[tc.param_name] = tc.param_value if thinking else "none" + + api_start = time.monotonic() + response = await self._client.chat.complete_async( + model=self._model_spec.api_model_name, + messages=api_messages, # type: ignore[arg-type] + tools=strict_tools, # type: ignore[arg-type] + tool_choice="any", + temperature=0.3, + max_tokens=max_tokens or self._model_spec.max_output_tokens, + **reasoning_params, + ) + api_time = time.monotonic() - api_start + + choice = response.choices[0] + message = choice.message + + # Parse content — may contain thinking chunks for reasoning models + raw_content = message.content + text, thinking_text = _extract_text_and_thinking(raw_content) + + # Only surface thinking when the caller requested it + if not thinking: + thinking_text = None + + tool_calls: list[LLMToolCall] = [] + if 
message.tool_calls: + for tc_item in message.tool_calls: + arguments = tc_item.function.arguments + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + tool_calls.append( + LLMToolCall( + id=tc_item.id or "", + name=tc_item.function.name, + arguments=arguments, + ) + ) + + prompt_tokens = 0 + completion_tokens = 0 + if response.usage: + prompt_tokens = response.usage.prompt_tokens or 0 + completion_tokens = response.usage.completion_tokens or 0 + self._logger.debug( + f"[MistralLLM] Tool call completed in {api_time:.2f}s, " + f"tokens: prompt={prompt_tokens}, completion={completion_tokens}" + ) + + return LLMResponse( + text=text, + thinking=thinking_text, + tool_calls=tool_calls, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + def _prepare_tools_strict(self, tools: list[dict]) -> list[dict]: + """Normalize tool parameter schemas for Mistral's json_schema strict mode.""" + strict_tools = [] + for tool in tools: + func = tool["function"] + params = self._normalize_strict_schema(func["parameters"]) + strict_tools.append( + { + "type": "function", + "function": { + "name": func["name"], + "description": func.get("description", ""), + "parameters": params, + }, + } + ) + return strict_tools + + async def close(self) -> None: + """Close the Mistral client and release resources.""" + if self._client: + # Mistral SDK uses context manager protocol (__aexit__) for cleanup + await self._client.__aexit__(None, None, None) + self._logger.debug("[MistralLLM] Client closed") + + +# ── Module-level helpers ─────────────────────────────────────────────── + + +def _extract_text(raw_content: AssistantMessageContent | Unset | None) -> str: + """Extract text from content, which may be a string or list of typed chunks.""" + if not isinstance(raw_content, (str, list)): + return "" + + if isinstance(raw_content, str): + return raw_content + + text_parts = [chunk.text for chunk in 
raw_content if isinstance(chunk, TextChunk)] + return "\n".join(text_parts) + + +def _extract_text_and_thinking( + raw_content: AssistantMessageContent | Unset | None, +) -> tuple[str | None, str | None]: + """Extract text and thinking from content chunks. + + Mistral reasoning models (Magistral, Mistral Small 4 with reasoning_effort) + return ThinkChunk blocks alongside TextChunk blocks in the content array. + ThinkChunk.thinking is a list of TextChunk/ReferenceChunk sub-items. + """ + if not isinstance(raw_content, (str, list)): + return None, None + + if isinstance(raw_content, str): + return raw_content or None, None + + text_parts: list[str] = [] + thinking_parts: list[str] = [] + + for chunk in raw_content: + if isinstance(chunk, ThinkChunk): + # ThinkChunk.thinking is List[Union[TextChunk, ReferenceChunk]] + for sub in chunk.thinking: + if isinstance(sub, TextChunk): + thinking_parts.append(sub.text) + elif isinstance(chunk, TextChunk): + text_parts.append(chunk.text) + + text = "\n".join(text_parts) if text_parts else None + thinking_text = "\n".join(thinking_parts) if thinking_parts else None + return text, thinking_text diff --git a/backend/airweave/adapters/llm/override.py b/backend/airweave/adapters/llm/override.py index af8a468b9..40295bae4 100644 --- a/backend/airweave/adapters/llm/override.py +++ b/backend/airweave/adapters/llm/override.py @@ -7,6 +7,7 @@ from airweave.adapters.llm.anthropic import AnthropicLLM from airweave.adapters.llm.cerebras import CerebrasLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.registry import ( LLMModel, LLMProvider, @@ -19,6 +20,7 @@ LLMProvider.ANTHROPIC: AnthropicLLM, LLMProvider.CEREBRAS: CerebrasLLM, LLMProvider.GROQ: GroqLLM, + LLMProvider.MISTRAL: MistralLLM, LLMProvider.TOGETHER: TogetherLLM, } diff --git a/backend/airweave/adapters/llm/registry.py b/backend/airweave/adapters/llm/registry.py index 01edfce60..0665b1ee1 100644 --- 
a/backend/airweave/adapters/llm/registry.py +++ b/backend/airweave/adapters/llm/registry.py @@ -23,6 +23,7 @@ class LLMProvider(str, Enum): GROQ = "groq" ANTHROPIC = "anthropic" TOGETHER = "together" + MISTRAL = "mistral" class LLMModel(str, Enum): @@ -45,6 +46,9 @@ class LLMModel(str, Enum): QWEN_3_5_DEDICATED = "qwen-3.5-dedicated" ZAI_GLM_5_DEDICATED = "zai-glm-5-dedicated" MINIMAX_M2_5 = "minimax-m2.5" + MISTRAL_LARGE = "mistral-large" + MISTRAL_SMALL = "mistral-small" + MAGISTRAL_SMALL = "magistral-small" @dataclass(frozen=True) @@ -205,6 +209,43 @@ class LLMModelSpec: output_price_factor=1.20, ), }, + LLMProvider.MISTRAL: { + LLMModel.MISTRAL_LARGE: LLMModelSpec( + api_model_name="mistral-large-latest", + context_window=256_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig(param_name="_noop", param_value=True), + input_price_factor=2.0, + output_price_factor=6.0, + ), + # Mistral Small 4 — adjustable reasoning via reasoning_effort + LLMModel.MISTRAL_SMALL: LLMModelSpec( + api_model_name="mistral-small-latest", + context_window=128_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig( + param_name="reasoning_effort", + param_value="high", + ), + input_price_factor=0.1, + output_price_factor=0.3, + ), + # Magistral Small — native reasoning (always-on thinking) + LLMModel.MAGISTRAL_SMALL: LLMModelSpec( + api_model_name="magistral-small-latest", + context_window=128_000, + max_output_tokens=40_000, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig(param_name="_noop", param_value=True), + input_price_factor=0.5, + output_price_factor=1.5, + ), + }, } @@ -214,6 +255,7 @@ class LLMModelSpec: LLMProvider.GROQ: 
"GROQ_API_KEY", LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", LLMProvider.TOGETHER: "TOGETHER_API_KEY", + LLMProvider.MISTRAL: "MISTRAL_API_KEY", } diff --git a/backend/airweave/adapters/llm/tests/test_mistral.py b/backend/airweave/adapters/llm/tests/test_mistral.py new file mode 100644 index 000000000..6dddc5bb1 --- /dev/null +++ b/backend/airweave/adapters/llm/tests/test_mistral.py @@ -0,0 +1,361 @@ +"""Tests for MistralLLM — mock the SDK client, not the network.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from mistralai.models.textchunk import TextChunk +from mistralai.models.thinkchunk import ThinkChunk +from pydantic import BaseModel + +from airweave.adapters.llm.exceptions import LLMProviderExhaustedError +from airweave.adapters.llm.mistral import MistralLLM +from airweave.adapters.llm.registry import LLMModelSpec, ThinkingConfig +from airweave.adapters.tokenizer.registry import TokenizerEncoding, TokenizerType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_spec( + thinking_param: str = "_noop", + thinking_value: str | bool = True, +) -> LLMModelSpec: + return LLMModelSpec( + api_model_name="mistral-large-latest", + context_window=256_000, + max_output_tokens=16_384, + required_tokenizer_type=TokenizerType.TIKTOKEN, + required_tokenizer_encoding=TokenizerEncoding.O200K_HARMONY, + thinking_config=ThinkingConfig( + param_name=thinking_param, + param_value=thinking_value, + ), + ) + + +class _DummyOutput(BaseModel): + key: str + + +def _mock_response( + content: str | list | None = '{"key": "value"}', + tool_calls: list | None = None, + prompt_tokens: int = 100, + completion_tokens: int = 50, + total_tokens: int = 150, +) -> MagicMock: + """Build a mock mimicking the Mistral SDK ChatCompletionResponse.""" + mock_choice = MagicMock() + 
mock_choice.message.content = content + mock_choice.message.tool_calls = tool_calls + + mock_resp = MagicMock() + mock_resp.choices = [mock_choice] + mock_resp.usage = MagicMock( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + return mock_resp + + +@pytest.fixture +def mistral_llm(): + """Instantiate MistralLLM with a patched settings object.""" + with patch("airweave.adapters.llm.mistral.settings") as mock_settings: + mock_settings.MISTRAL_API_KEY = "test-key" + llm = MistralLLM(model_spec=_make_spec(), max_retries=0) + yield llm + + +@pytest.fixture +def mistral_llm_reasoning(): + """MistralLLM configured with reasoning_effort thinking config.""" + with patch("airweave.adapters.llm.mistral.settings") as mock_settings: + mock_settings.MISTRAL_API_KEY = "test-key" + llm = MistralLLM( + model_spec=_make_spec( + thinking_param="reasoning_effort", + thinking_value="high", + ), + max_retries=0, + ) + yield llm + + +# ═══════════════════════════════════════════════════════════════════════════ +# structured_output tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_structured_output_returns_parsed(mistral_llm: MistralLLM) -> None: + """_call_api parses JSON content into the Pydantic model.""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content='{"key": "hello"}') + ) + + result = await mistral_llm.structured_output( + prompt="test prompt", + schema=_DummyOutput, + system_prompt="sys", + ) + + assert isinstance(result, _DummyOutput) + assert result.key == "hello" + + +@pytest.mark.asyncio +async def test_empty_response_raises_transient(mistral_llm: MistralLLM) -> None: + """Empty content from the API raises LLMProviderExhaustedError (wrapping transient).""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=None) + ) + + with 
pytest.raises(LLMProviderExhaustedError, match="empty response"): + await mistral_llm.structured_output( + prompt="test prompt", + schema=_DummyOutput, + system_prompt="sys", + ) + + +@pytest.mark.asyncio +async def test_structured_output_uses_json_schema(mistral_llm: MistralLLM) -> None: + """_call_api sends json_schema response_format.""" + mock_create = AsyncMock( + return_value=_mock_response(content='{"key": "v"}') + ) + mistral_llm._client.chat.complete_async = mock_create + + await mistral_llm.structured_output( + prompt="test", + schema=_DummyOutput, + system_prompt="sys", + ) + + call_kwargs = mock_create.call_args.kwargs + rf = call_kwargs["response_format"] + assert rf.type == "json_schema" + assert rf.json_schema.strict is True + + +# ═══════════════════════════════════════════════════════════════════════════ +# chat tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_chat_returns_tool_calls(mistral_llm: MistralLLM) -> None: + """chat() extracts tool_calls from the response.""" + mock_tc = MagicMock() + mock_tc.id = "tc-1" + mock_tc.function.name = "search" + mock_tc.function.arguments = '{"query": "hello"}' + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content="text", tool_calls=[mock_tc]) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "search", + "description": "Search", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "search" + assert result.tool_calls[0].arguments == {"query": "hello"} + assert result.prompt_tokens == 100 + + +@pytest.mark.asyncio +async def test_chat_dict_arguments(mistral_llm: MistralLLM) -> None: + """chat() handles arguments already returned as dict (not stringified).""" + mock_tc = 
MagicMock() + mock_tc.id = "tc-2" + mock_tc.function.name = "lookup" + mock_tc.function.arguments = {"id": 42} + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=None, tool_calls=[mock_tc]) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "lookup", + "description": "Lookup", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert result.tool_calls[0].arguments == {"id": 42} + + +@pytest.mark.asyncio +async def test_chat_uses_tool_choice_any(mistral_llm: MistralLLM) -> None: + """chat() passes tool_choice='any' to force tool usage.""" + mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm._client.chat.complete_async = mock_create + + await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["tool_choice"] == "any" + + +# ═══════════════════════════════════════════════════════════════════════════ +# thinking/reasoning tests +# ═══════════════════════════════════════════════════════════════════════════ + + +@pytest.mark.asyncio +async def test_chat_extracts_thinking_from_content_chunks(mistral_llm: MistralLLM) -> None: + """chat() extracts thinking from ThinkChunk content blocks.""" + # Simulate Magistral-style response with thinking + text chunks + think_chunk = ThinkChunk( + thinking=[TextChunk(text="Let me reason about this...")], + type="thinking", + ) + text_chunk = TextChunk(text="The answer is 42", type="text") + + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content=[think_chunk, text_chunk], tool_calls=None) + ) + + 
result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=True, + ) + + assert result.thinking == "Let me reason about this..." + assert result.text == "The answer is 42" + + +@pytest.mark.asyncio +async def test_chat_no_thinking_returns_none(mistral_llm: MistralLLM) -> None: + """chat() returns thinking=None when no ThinkChunk is present.""" + mistral_llm._client.chat.complete_async = AsyncMock( + return_value=_mock_response(content="plain text", tool_calls=None) + ) + + result = await mistral_llm.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + ) + + assert result.thinking is None + assert result.text == "plain text" + + +@pytest.mark.asyncio +async def test_chat_reasoning_effort_passed(mistral_llm_reasoning: MistralLLM) -> None: + """chat(thinking=True) passes reasoning_effort='high' for Small 4 models.""" + mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm_reasoning._client.chat.complete_async = mock_create + + await mistral_llm_reasoning.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=True, + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["reasoning_effort"] == "high" + + +@pytest.mark.asyncio +async def test_chat_reasoning_effort_none_when_not_thinking( + mistral_llm_reasoning: MistralLLM, +) -> None: + """chat(thinking=False) passes reasoning_effort='none' for Small 4 models.""" + 
mock_create = AsyncMock( + return_value=_mock_response(content=None, tool_calls=None) + ) + mistral_llm_reasoning._client.chat.complete_async = mock_create + + await mistral_llm_reasoning.chat( + messages=[{"role": "user", "content": "hi"}], + tools=[ + { + "type": "function", + "function": { + "name": "noop", + "description": "", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + system_prompt="sys", + thinking=False, + ) + + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["reasoning_effort"] == "none" diff --git a/backend/airweave/adapters/llm/tests/test_unavailable.py b/backend/airweave/adapters/llm/tests/test_unavailable.py new file mode 100644 index 000000000..5eae868b9 --- /dev/null +++ b/backend/airweave/adapters/llm/tests/test_unavailable.py @@ -0,0 +1,51 @@ +"""Tests for UnavailableLLM null-object adapter.""" + +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS +from airweave.adapters.llm.unavailable import UnavailableLLM +from airweave.core.exceptions import LLMUnavailableError + + +class _Dummy(BaseModel): + key: str + + +@pytest.mark.asyncio +async def test_structured_output_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + await llm.structured_output(prompt="x", schema=_Dummy, system_prompt="y") + + +@pytest.mark.asyncio +async def test_chat_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + await llm.chat(messages=[], tools=[], system_prompt="y") + + +def test_model_spec_raises_llm_unavailable_error() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError): + _ = llm.model_spec + + +@pytest.mark.asyncio +async def test_close_is_a_safe_noop() -> None: + llm = UnavailableLLM() + await llm.close() # must not raise; returns None by type + + +def 
test_error_message_mentions_accepted_api_key_env_vars() -> None: + llm = UnavailableLLM() + with pytest.raises(LLMUnavailableError) as excinfo: + _ = llm.model_spec + + message = str(excinfo.value) + for env_var in PROVIDER_API_KEY_SETTINGS.values(): + assert env_var in message, f"{env_var} missing from error message" + assert "LLM_FALLBACK_CHAIN" in message diff --git a/backend/airweave/adapters/llm/unavailable.py b/backend/airweave/adapters/llm/unavailable.py new file mode 100644 index 000000000..8d4c76a81 --- /dev/null +++ b/backend/airweave/adapters/llm/unavailable.py @@ -0,0 +1,66 @@ +"""Null-object LLM provider for deployments without any configured API key. + +Wired into the container when LLM_FALLBACK_CHAIN has no entries whose API key is +set. Instant search — which does not use an LLM — keeps working. Classic and +agentic search services are unchanged (they still expect a non-null LLMProtocol); +the failure surfaces on first use as LLMUnavailableError, which the FastAPI +exception handler maps to HTTP 503. +""" + +from __future__ import annotations + +from typing import TypeVar + +from pydantic import BaseModel + +from airweave.adapters.llm.registry import PROVIDER_API_KEY_SETTINGS, LLMModelSpec +from airweave.adapters.llm.tool_response import LLMResponse +from airweave.core.exceptions import LLMUnavailableError + +T = TypeVar("T", bound=BaseModel) + +_DETAILED_MESSAGE = ( + "No LLM provider configured. Set one of: " + f"{', '.join(PROVIDER_API_KEY_SETTINGS.values())} — " + "or customize the chain via LLM_FALLBACK_CHAIN " + "(format: 'provider:model,provider:model')." +) + + +class UnavailableLLM: + """LLMProtocol implementation that raises on every call. + + The protocol is structural (``typing.Protocol``), so no inheritance is + required. Every method and the ``model_spec`` property raise + ``LLMUnavailableError`` with an actionable message. 
+ """ + + @property + def model_spec(self) -> LLMModelSpec: + """Raise because no provider is configured.""" + raise LLMUnavailableError(_DETAILED_MESSAGE) + + async def structured_output( + self, + prompt: str, + schema: type[T], + system_prompt: str, + thinking: bool = False, + ) -> T: + """Raise because no provider is configured.""" + raise LLMUnavailableError(_DETAILED_MESSAGE) + + async def chat( + self, + messages: list[dict], + tools: list[dict], + system_prompt: str, + thinking: bool = False, + max_tokens: int | None = None, + ) -> LLMResponse: + """Raise because no provider is configured.""" + raise LLMUnavailableError(_DETAILED_MESSAGE) + + async def close(self) -> None: + """No-op: the null-object holds no resources.""" + return None diff --git a/backend/airweave/api/middleware.py b/backend/airweave/api/middleware.py index 892c51c85..6237aefa0 100644 --- a/backend/airweave/api/middleware.py +++ b/backend/airweave/api/middleware.py @@ -22,6 +22,7 @@ AirweaveException, InvalidInputError, InvalidStateError, + LLMUnavailableError, NotFoundException, PermissionException, RateLimitExceededException, @@ -437,6 +438,7 @@ async def airweave_exception_handler(request: Request, exc: AirweaveException) - # Add new base classes here as they're introduced (BadRequestError, etc.). 
status_map = { TokenRefreshError: 401, + LLMUnavailableError: 503, } for exc_type, code in status_map.items(): diff --git a/backend/airweave/api/v1/endpoints/search.py b/backend/airweave/api/v1/endpoints/search.py index 9a8aa5c19..a9c6b5c74 100644 --- a/backend/airweave/api/v1/endpoints/search.py +++ b/backend/airweave/api/v1/endpoints/search.py @@ -10,7 +10,7 @@ import json from collections.abc import AsyncGenerator -from fastapi import Depends, Path +from fastapi import Depends, Path, Query from sqlalchemy.ext.asyncio import AsyncSession from starlette.responses import StreamingResponse @@ -349,3 +349,97 @@ async def admin_stream_agentic_search( media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) + + +@admin_router.post( + "/{readable_id}/search/instant/as-user", + response_model=SearchV2Response, + summary="Instant Search as User (Admin)", + description=( + "Admin-only: Instant search with access control applied for a specific user principal." + ), +) +async def admin_instant_search_as_user( + readable_id: str = Path(...), + request: InstantSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. 
" + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: InstantSearchServiceProtocol = Inject(InstantSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: instant search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.QUERIES) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + return SearchV2Response(results=results.results) + + +@admin_router.post( + "/{readable_id}/search/classic/as-user", + response_model=SearchV2Response, + summary="Classic Search as User (Admin)", + description=( + "Admin-only: Classic search with access control applied for a specific user principal." + ), +) +async def admin_classic_search_as_user( + readable_id: str = Path(...), + request: ClassicSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. 
" + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: ClassicSearchServiceProtocol = Inject(ClassicSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: classic search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.QUERIES) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + return SearchV2Response(results=results.results) + + +@admin_router.post( + "/{readable_id}/search/agentic/as-user", + response_model=SearchV2Response, + summary="Agentic Search as User (Admin)", + description=( + "Admin-only: Agentic search with access control applied for a specific user principal." + ), +) +async def admin_agentic_search_as_user( + readable_id: str = Path(...), + request: AgenticSearchRequest = ..., # type: ignore[assignment] + user_principal: str = Query( + ..., + description="User principal (email) to search as. 
" + "Access control filtering will use this user's resolved group memberships.", + ), + db: AsyncSession = Depends(deps.get_db), + ctx: ApiContext = Depends(deps.get_context), + usage_checker: UsageLimitCheckerProtocol = Inject(UsageLimitCheckerProtocol), + service: AgenticSearchServiceProtocol = Inject(AgenticSearchServiceProtocol), +) -> SearchV2Response: + """Admin-only: agentic search with ACL filtering for a specific user.""" + _require_admin(ctx) + await usage_checker.is_allowed(db, ctx.organization.id, ActionType.TOKENS) + + results = await service.search( + db, ctx, readable_id, request, user_principal_override=user_principal + ) + truncated = results.results[: request.limit] if request.limit else results.results + return SearchV2Response(results=truncated) diff --git a/backend/airweave/core/config/settings.py b/backend/airweave/core/config/settings.py index 1ab327bd0..a5bc6c69e 100644 --- a/backend/airweave/core/config/settings.py +++ b/backend/airweave/core/config/settings.py @@ -195,6 +195,13 @@ class Settings(BaseSettings): TOGETHER_API_KEY: Optional[str] = None AZURE_KEYVAULT_NAME: Optional[str] = None + # Overrides SearchConfig.LLM_FALLBACK_CHAIN when set. + # Format: comma-separated "provider:model" pairs using the values from + # airweave.adapters.llm.registry (e.g. "mistral:mistral-large" or + # "together:zai-glm-5,anthropic:claude-sonnet-4.6"). Unset → use the + # in-code default in domains/search/config.py. 
+ LLM_FALLBACK_CHAIN: Optional[str] = None + # Docling OCR fallback service (None = disabled) DOCLING_BASE_URL: Optional[str] = None diff --git a/backend/airweave/core/container/container.py b/backend/airweave/core/container/container.py index b2fb6d25a..4b0822ad0 100644 --- a/backend/airweave/core/container/container.py +++ b/backend/airweave/core/container/container.py @@ -99,11 +99,8 @@ ) from airweave.domains.syncs.protocols import ( SyncCursorRepositoryProtocol, - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, SyncRepositoryProtocol, SyncServiceProtocol, - SyncStateMachineProtocol, ) from airweave.domains.temporal.protocols import ( TemporalScheduleServiceProtocol, @@ -203,15 +200,11 @@ async def my_endpoint(event_bus: EventBus = Inject(EventBus)): # Sync domain sync_repo: SyncRepositoryProtocol sync_cursor_repo: SyncCursorRepositoryProtocol - # Sync cursor service — cursor CRUD operations sync_cursor_service: SyncCursorService sync_job_repo: SyncJobRepositoryProtocol - sync_record_service: SyncRecordServiceProtocol sync_job_service: SyncJobServiceProtocol sync_job_state_machine: SyncJobStateMachineProtocol - sync_state_machine: SyncStateMachineProtocol sync_service: SyncServiceProtocol - sync_lifecycle: SyncLifecycleServiceProtocol sync_factory: SyncFactoryProtocol entity_repo: EntityRepositoryProtocol diff --git a/backend/airweave/core/container/factory.py b/backend/airweave/core/container/factory.py index 61565cc00..a18a64669 100644 --- a/backend/airweave/core/container/factory.py +++ b/backend/airweave/core/container/factory.py @@ -24,6 +24,7 @@ from airweave.adapters.llm.cerebras import CerebrasLLM from airweave.adapters.llm.fallback import FallbackChainLLM from airweave.adapters.llm.groq import GroqLLM +from airweave.adapters.llm.mistral import MistralLLM from airweave.adapters.llm.registry import ( PROVIDER_API_KEY_SETTINGS, LLMProvider, @@ -32,6 +33,7 @@ get_model_spec as get_llm_model_spec, ) from airweave.adapters.llm.together import 
TogetherLLM +from airweave.adapters.llm.unavailable import UnavailableLLM from airweave.adapters.metrics import ( PrometheusAgenticSearchMetrics, PrometheusDbPoolMetrics, @@ -129,8 +131,6 @@ from airweave.domains.syncs.jobs.repository import SyncJobRepository from airweave.domains.syncs.jobs.service import SyncJobService from airweave.domains.syncs.jobs.state_machine import SyncJobStateMachine -from airweave.domains.syncs.lifecycle_service import SyncLifecycleService -from airweave.domains.syncs.record_service import SyncRecordService from airweave.domains.syncs.repository import SyncRepository from airweave.domains.syncs.service import SyncService from airweave.domains.syncs.state_machine import SyncStateMachine @@ -401,41 +401,24 @@ def create_container(settings: Settings) -> Container: ) sync_service = SyncService( - state_machine=sync_deps["sync_job_state_machine"], - sync_factory=sync_factory, - sync_state_machine=sync_deps["sync_state_machine"], - ) - - sync_record_service = SyncRecordService( sync_repo=source_deps["sync_repo"], sync_job_repo=source_deps["sync_job_repo"], - connection_repo=source_deps["conn_repo"], - ) - - sync_lifecycle = SyncLifecycleService( - sc_repo=source_deps["sc_repo"], - collection_repo=source_deps["collection_repo"], - connection_repo=source_deps["conn_repo"], sync_cursor_repo=sync_deps["sync_cursor_repo"], - sync_service=sync_record_service, - state_machine=sync_deps["sync_job_state_machine"], - sync_job_repo=source_deps["sync_job_repo"], + state_machine=sync_deps["sync_state_machine"], + job_state_machine=sync_deps["sync_job_state_machine"], temporal_workflow_service=sync_deps["temporal_workflow_service"], temporal_schedule_service=sync_deps["temporal_schedule_service"], - response_builder=sync_deps["response_builder"], - event_bus=event_bus, + sync_factory=sync_factory, ) # ----------------------------------------------------------------- - # Source connection sub-services + # Source connection sub-services (need sync_service) # 
----------------------------------------------------------------- deletion_service = SourceConnectionDeletionService( sc_repo=source_deps["sc_repo"], collection_repo=source_deps["collection_repo"], - sync_job_repo=source_deps["sync_job_repo"], - sync_lifecycle=sync_lifecycle, response_builder=sync_deps["response_builder"], - temporal_workflow_service=sync_deps["temporal_workflow_service"], + sync_service=sync_service, ) update_service = SourceConnectionUpdateService( sc_repo=source_deps["sc_repo"], @@ -443,13 +426,13 @@ def create_container(settings: Settings) -> Container: connection_repo=source_deps["conn_repo"], cred_repo=source_deps["cred_repo"], sync_repo=source_deps["sync_repo"], - sync_record_service=sync_record_service, + sync_service=sync_service, source_service=source_deps["source_service"], + source_registry=source_deps["source_registry"], source_validation=source_validation, credential_encryptor=encryptor, response_builder=sync_deps["response_builder"], temporal_schedule_service=sync_deps["temporal_schedule_service"], - sync_state_machine=sync_deps["sync_state_machine"], ) create_service = SourceConnectionCreationService( sc_repo=source_deps["sc_repo"], @@ -459,8 +442,7 @@ def create_container(settings: Settings) -> Container: source_registry=source_deps["source_registry"], source_validation=source_validation, source_lifecycle=source_deps["source_lifecycle_service"], - sync_lifecycle=sync_lifecycle, - sync_record_service=sync_record_service, + sync_service=sync_service, response_builder=sync_deps["response_builder"], oauth_flow_service=oauth_flow_svc, temporal_workflow_service=sync_deps["temporal_workflow_service"], @@ -476,7 +458,8 @@ def create_container(settings: Settings) -> Container: source_registry=source_deps["source_registry"], auth_provider_registry=source_deps["auth_provider_registry"], response_builder=sync_deps["response_builder"], - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, + event_bus=event_bus, 
create_service=create_service, update_service=update_service, deletion_service=deletion_service, @@ -496,6 +479,7 @@ def create_container(settings: Settings) -> Container: entity_definition_registry=source_deps["entity_definition_registry"], event_bus=event_bus, source_lifecycle=source_deps["source_lifecycle_service"], + access_broker=access_broker, ) # ----------------------------------------------------------------- @@ -504,7 +488,7 @@ def create_container(settings: Settings) -> Container: collection_service = CollectionService( collection_repo=source_deps["collection_repo"], sc_repo=source_deps["sc_repo"], - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, settings=settings, deployment_metadata_repo=VectorDbDeploymentMetadataRepository(), @@ -520,10 +504,8 @@ def create_container(settings: Settings) -> Container: response_builder=sync_deps["response_builder"], source_registry=source_deps["source_registry"], source_lifecycle=source_deps["source_lifecycle_service"], - sync_lifecycle=sync_lifecycle, - sync_record_service=sync_record_service, + sync_service=sync_service, temporal_workflow_service=sync_deps["temporal_workflow_service"], - sync_state_machine=sync_deps["sync_state_machine"], event_bus=event_bus, organization_repo=OrgRepo(), sc_repo=source_deps["sc_repo"], @@ -618,9 +600,6 @@ def create_container(settings: Settings) -> Container: sync_job_service=sync_deps["sync_job_service"], sync_job_state_machine=sync_deps["sync_job_state_machine"], sync_service=sync_service, - sync_record_service=sync_record_service, - sync_lifecycle=sync_lifecycle, - sync_state_machine=sync_deps["sync_state_machine"], sync_factory=sync_factory, entity_repo=sync_deps["entity_repo"], access_broker=access_broker, @@ -1208,19 +1187,26 @@ def _build_llm_chain( ): """Build LLM fallback chain from SearchConfig, skipping providers without API keys. - Returns: - An LLM instance (single provider or FallbackChainLLM). 
- - Raises: - ValueError: If no LLM providers are available. + When no provider in the chain has a configured API key (or all fail to initialize), + returns an ``UnavailableLLM`` null-object rather than raising. This keeps the + backend bootable for deployers who only use Instant search. Classic and agentic + search surface ``LLMUnavailableError`` on first invocation, mapped to HTTP 503. """ provider_classes = { LLMProvider.ANTHROPIC: AnthropicLLM, LLMProvider.CEREBRAS: CerebrasLLM, LLMProvider.GROQ: GroqLLM, + LLMProvider.MISTRAL: MistralLLM, LLMProvider.TOGETHER: TogetherLLM, } + def _unavailable(reason: str, level: str = "info") -> UnavailableLLM: + getattr(logger, level)( + f"[SearchFactory] {reason} — classic/agentic search will return HTTP 503 " + "until a key is set. Instant search is unaffected." + ) + return UnavailableLLM() + # Collect available (provider, model_spec, class) tuples first, # then decide retry strategy based on how many survived. available = [] @@ -1239,10 +1225,7 @@ def _build_llm_chain( available.append((provider, model, model_spec, provider_cls)) if not available: - raise ValueError( - "No LLM providers available for search. " - "Configure at least one API key from SearchConfig.LLM_FALLBACK_CHAIN." - ) + return _unavailable("No LLM provider API keys configured") # Single provider: use default retries. # Multiple providers: max_retries=0 for all except the last provider, @@ -1264,9 +1247,7 @@ def _build_llm_chain( ) if not llm_providers: - raise ValueError( - "No LLM providers available for search. All configured providers failed to initialize." 
- ) + return _unavailable("All configured LLM providers failed to initialize", level="warning") if len(llm_providers) == 1: return llm_providers[0] @@ -1284,6 +1265,7 @@ def _create_search_services( entity_definition_registry: "EntityDefinitionRegistry", event_bus: "EventBus", source_lifecycle: "SourceLifecycleService", + access_broker: "AccessBroker", ) -> dict: """Create search domain services (LLM, tokenizer, reranker, metadata builder, per-tier). @@ -1297,9 +1279,6 @@ def _create_search_services( config = SearchConfig() # 1. Tokenizer — validate against primary LLM model requirements - if not config.LLM_FALLBACK_CHAIN: - raise ValueError("LLM_FALLBACK_CHAIN is empty — at least one provider is required") - primary_provider, primary_model = config.LLM_FALLBACK_CHAIN[0] primary_llm_spec = get_llm_model_spec(primary_provider, primary_model) @@ -1351,6 +1330,7 @@ def _create_search_services( sc_repo=sc_repo, source_registry=source_registry, source_lifecycle=source_lifecycle, + access_broker=access_broker, ) # 6. 
Per-tier services diff --git a/backend/airweave/core/container/tests/test_llm_chain_wiring.py b/backend/airweave/core/container/tests/test_llm_chain_wiring.py new file mode 100644 index 000000000..3226e1e3b --- /dev/null +++ b/backend/airweave/core/container/tests/test_llm_chain_wiring.py @@ -0,0 +1,53 @@ +"""Tests for _build_llm_chain: null-object fallback when no providers resolve.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.adapters.llm.unavailable import UnavailableLLM +from airweave.core.container.factory import _build_llm_chain + + +def _settings_with_no_keys() -> MagicMock: + s = MagicMock() + for attr in ( + "TOGETHER_API_KEY", + "ANTHROPIC_API_KEY", + "MISTRAL_API_KEY", + "GROQ_API_KEY", + "CEREBRAS_API_KEY", + ): + setattr(s, attr, None) + return s + + +def _config_with_chain(chain: list[tuple[LLMProvider, LLMModel]]) -> MagicMock: + config = MagicMock() + config.LLM_FALLBACK_CHAIN = chain + return config + + +def test_returns_unavailable_llm_when_no_keys_configured() -> None: + settings = _settings_with_no_keys() + config = _config_with_chain( + [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), + ] + ) + circuit_breaker = MagicMock() + + llm = _build_llm_chain(settings, config, circuit_breaker) + + assert isinstance(llm, UnavailableLLM) + + +def test_returns_unavailable_llm_when_chain_is_empty() -> None: + settings = _settings_with_no_keys() + config = _config_with_chain([]) + circuit_breaker = MagicMock() + + llm = _build_llm_chain(settings, config, circuit_breaker) + + assert isinstance(llm, UnavailableLLM) diff --git a/backend/airweave/core/exceptions.py b/backend/airweave/core/exceptions.py index eea773d88..db8f69c65 100644 --- a/backend/airweave/core/exceptions.py +++ b/backend/airweave/core/exceptions.py @@ -208,6 +208,23 @@ def __init__( super().__init__(self.message) +class 
LLMUnavailableError(AirweaveException): + """Raised when an LLM-backed feature is requested but no LLM provider is configured. + + The container wires an UnavailableLLM null-object when LLM_FALLBACK_CHAIN has no + entries with a configured API key. Instant search still works; classic/agentic + search surface this on first use and map to HTTP 503. + """ + + def __init__( + self, + message: str = ("No LLM provider configured. See LLM_FALLBACK_CHAIN docs for setup."), + ): + """Create a new LLMUnavailableError with an actionable message.""" + self.message = message + super().__init__(self.message) + + class SourceRateLimitExceededException(Exception): """Exception raised when source API rate limit is exceeded. diff --git a/backend/airweave/core/shared_models.py b/backend/airweave/core/shared_models.py index 70227466a..7139a8fb9 100644 --- a/backend/airweave/core/shared_models.py +++ b/backend/airweave/core/shared_models.py @@ -109,6 +109,7 @@ class FeatureFlag(str, Enum): # These allow specific admin operations via API key authentication API_KEY_ADMIN_SYNC = "api_key_admin_sync" # Allows resync operations via API key CONNECT = "connect" # Enables the Connect playground and embeddable widget features + CUSTOM_AUTH_PROVIDER = "custom_auth_provider" # Enables the Custom auth provider class AuthMethod(str, Enum): diff --git a/backend/airweave/crud/crud_sync_job.py b/backend/airweave/crud/crud_sync_job.py index 7056a868d..fa906b682 100644 --- a/backend/airweave/crud/crud_sync_job.py +++ b/backend/airweave/crud/crud_sync_job.py @@ -39,8 +39,9 @@ async def get_all_by_sync_id( db: AsyncSession, sync_id: UUID, status: Optional[list[str]] = None, + limit: Optional[int] = None, ) -> list[SyncJob]: - """Get all jobs for a specific sync, optionally filtered by status.""" + """Get jobs for a sync; optional status filter; newest first; optional row limit.""" stmt = ( select(SyncJob, Sync.name.label("sync_name")) .join(Sync, SyncJob.sync_id == Sync.id) @@ -52,6 +53,10 @@ async def 
get_all_by_sync_id( # Database enum already uses uppercase values stmt = stmt.where(SyncJob.status.in_(status)) + stmt = stmt.order_by(SyncJob.created_at.desc()) + if limit is not None: + stmt = stmt.limit(limit) + result = await db.execute(stmt) jobs = [] for job, sync_name in result: diff --git a/backend/airweave/domains/auth_provider/_base.py b/backend/airweave/domains/auth_provider/_base.py index 24e833b40..537acbdf6 100644 --- a/backend/airweave/domains/auth_provider/_base.py +++ b/backend/airweave/domains/auth_provider/_base.py @@ -2,12 +2,19 @@ from abc import ABC, abstractmethod from typing import Any, ClassVar, Dict, List, Optional, Set +from uuid import UUID from pydantic import BaseModel from airweave.core.logging import logger from airweave.domains.auth_provider.auth_result import AuthResult +# OAuth lifecycle fields that auth providers handle internally — +# always optional when fetching credentials from a provider. +AUTH_PROVIDER_OPTIONAL_FIELDS: frozenset[str] = frozenset( + {"refresh_token", "client_id", "client_secret"} +) + class BaseAuthProvider(ABC): """Base class for all auth providers.""" @@ -19,6 +26,7 @@ class BaseAuthProvider(ABC): auth_config_class: ClassVar[Optional[type[BaseModel]]] = None config_class: ClassVar[Optional[type[BaseModel]]] = None SETTINGS_URL: ClassVar[str] = "" + feature_flag: ClassVar[Optional[str]] = None def __init__(self): """Initialize the base auth provider.""" @@ -58,6 +66,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a source. 
@@ -65,6 +74,8 @@ async def get_creds_for_source( source_short_name: The short name of the source to get credentials for source_auth_config_fields: The fields required for the source auth config optional_fields: Fields that can be skipped if the provider doesn't have them + source_connection_id: UUID of the source connection (used by Custom provider + to scope credentials per connection) """ pass @@ -104,6 +115,7 @@ async def get_auth_result( source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, source_config_field_mappings: Optional[Dict[str, str]] = None, + source_connection_id: Optional[UUID] = None, ) -> AuthResult: """Get auth result with credentials for a source. @@ -115,12 +127,16 @@ async def get_auth_result( source_auth_config_fields: The fields required for the source auth config optional_fields: Fields that can be skipped if the provider doesn't have them source_config_field_mappings: Mapping of config fields extractable from auth response + source_connection_id: UUID of the source connection Returns: AuthResult with credentials and optional source config """ credentials = await self.get_creds_for_source( - source_short_name, source_auth_config_fields, optional_fields + source_short_name, + source_auth_config_fields, + optional_fields, + source_connection_id=source_connection_id, ) source_config = {} diff --git a/backend/airweave/domains/auth_provider/providers/__init__.py b/backend/airweave/domains/auth_provider/providers/__init__.py index 8d5dd9306..7805a13e0 100644 --- a/backend/airweave/domains/auth_provider/providers/__init__.py +++ b/backend/airweave/domains/auth_provider/providers/__init__.py @@ -1,9 +1,11 @@ """Auth provider implementations.""" from .composio import ComposioAuthProvider +from .custom import CustomAuthProvider from .pipedream import PipedreamAuthProvider ALL_AUTH_PROVIDERS: list[type] = [ ComposioAuthProvider, + CustomAuthProvider, PipedreamAuthProvider, ] diff --git 
a/backend/airweave/domains/auth_provider/providers/composio.py b/backend/airweave/domains/auth_provider/providers/composio.py index edfb52f48..9d95daf5f 100644 --- a/backend/airweave/domains/auth_provider/providers/composio.py +++ b/backend/airweave/domains/auth_provider/providers/composio.py @@ -1,6 +1,7 @@ """Composio Test Auth Provider - provides authentication services for other integrations.""" from typing import Any, Dict, List, Optional, Set +from uuid import UUID import httpx @@ -211,6 +212,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a specific source integration. diff --git a/backend/airweave/domains/auth_provider/providers/custom.py b/backend/airweave/domains/auth_provider/providers/custom.py new file mode 100644 index 000000000..9440a7baf --- /dev/null +++ b/backend/airweave/domains/auth_provider/providers/custom.py @@ -0,0 +1,223 @@ +"""Custom Auth Provider - fetches tokens from a customer-hosted HTTP endpoint.""" + +from typing import Any, Dict, List, Optional, Set +from uuid import UUID + +import httpx + +from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.platform.configs.auth import CustomAuthConfig +from airweave.platform.configs.config import CustomConfig +from airweave.platform.decorators import auth_provider +from airweave.platform.utils.ssrf import SSRFViolation, validate_url + + +@auth_provider( + name="Custom", + short_name="custom", + auth_config_class=CustomAuthConfig, + config_class=CustomConfig, + feature_flag="custom_auth_provider", +) +class CustomAuthProvider(BaseAuthProvider): + """Custom authentication provider. 
+ + Calls GET {base_url}/{source_connection_id} on a customer-hosted endpoint + to fetch fresh credentials. The customer is responsible for returning + the freshest credentials as JSON. + """ + + BLOCKED_SOURCES: list[str] = ["ctti"] + + # Map Airweave-internal field names to the simple names customers return. + # Customers always return {"access_token": "..."} or {"api_key": "..."}. + FIELD_NAME_MAPPING: Dict[str, str] = { + "personal_access_token": "access_token", # GitHub + "api_token": "access_token", # Document360, Pipedrive + } + + # Instance attributes set in create() + base_endpoint_url: str + api_key: str + + @classmethod + async def create( + cls, + credentials: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "CustomAuthProvider": + """Create a new Custom auth provider instance.""" + if credentials is None: + raise ValueError("credentials parameter is required") + auth_config = CustomAuthConfig(**credentials) + instance = cls() + instance.base_endpoint_url = auth_config.base_endpoint_url + instance.api_key = auth_config.api_key + return instance + + def _build_headers(self) -> Dict[str, str]: + """Build request headers with API key authentication.""" + return { + "Accept": "application/json", + "X-API-Key": self.api_key, + } + + def _check_ssrf(self, url: str) -> None: + """Validate URL against SSRF blocklist before making a request.""" + try: + validate_url(url) + except SSRFViolation as exc: + self.logger.warning(f"[Custom] SSRF blocked: {exc}") + raise AuthProviderConfigError( + f"Custom endpoint URL blocked by SSRF policy: {exc}", + provider_name="custom", + ) from exc + + def _raise_for_http_status(self, e: httpx.HTTPStatusError, source_short_name: str) -> None: + """Classify an HTTP error status into the appropriate auth provider exception.""" + status = e.response.status_code + self.logger.error(f"[Custom] HTTP {status} from endpoint for source '{source_short_name}'") + if status in (401, 403): + raise 
AuthProviderAuthError( + f"Custom endpoint returned {status} for source '{source_short_name}'", + provider_name="custom", + ) from e + if status == 429: + retry_after = float(e.response.headers.get("retry-after", 30)) + raise AuthProviderRateLimitError( + f"Custom endpoint rate-limited for source '{source_short_name}'", + provider_name="custom", + retry_after=retry_after, + ) from e + if status == 404: + raise AuthProviderMissingFieldsError( + f"Custom endpoint has no credentials configured for " + f"source '{source_short_name}' (404)", + provider_name="custom", + missing_fields=[], + available_fields=[], + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Custom endpoint returned {status} for source '{source_short_name}'", + provider_name="custom", + status_code=status, + ) from e + raise AuthProviderConfigError( + f"Custom endpoint returned unexpected {status} for source '{source_short_name}'", + provider_name="custom", + ) from e + + async def get_creds_for_source( + self, + source_short_name: str, + source_auth_config_fields: List[str], + optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, + ) -> Dict[str, Any]: + """Get credentials for a source by calling GET {base_url}/{source_connection_id}.""" + if not source_connection_id: + raise AuthProviderConfigError( + "Custom auth provider requires a source_connection_id", + provider_name="custom", + ) + _optional_fields = optional_fields or set() + headers = self._build_headers() + url = f"{self.base_endpoint_url}/{source_connection_id}" + + self._check_ssrf(url) + self.logger.info(f"[Custom] Fetching credentials for source '{source_short_name}'") + + async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client: + try: + response = await client.get(url, headers=headers) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + self._raise_for_http_status(e, source_short_name) + except (httpx.ConnectError, 
httpx.TimeoutException) as e: + self.logger.error(f"[Custom] Network error reaching endpoint: {e}") + raise AuthProviderTemporaryError( + f"Custom endpoint unreachable: {e}", + provider_name="custom", + ) from e + + missing_fields = [] + found_credentials: Dict[str, Any] = {} + + for field in source_auth_config_fields: + # Check the response using the mapped name (e.g. access_token for + # personal_access_token), then store under the Airweave-internal name. + mapped = self.FIELD_NAME_MAPPING.get(field, field) + if mapped in data: + found_credentials[field] = data[mapped] + elif field in data: + found_credentials[field] = data[field] + elif field not in _optional_fields: + missing_fields.append(mapped) + + if missing_fields: + available = list(data.keys()) + self.logger.error( + f"[Custom] Missing required fields for source '{source_short_name}': " + f"{missing_fields}. Available: {available}" + ) + raise AuthProviderMissingFieldsError( + f"Custom endpoint response missing required fields for " + f"source '{source_short_name}': {missing_fields}", + provider_name="custom", + missing_fields=missing_fields, + available_fields=available, + ) + + self.logger.info( + f"[Custom] Successfully retrieved {len(found_credentials)} credential fields " + f"for source '{source_short_name}'" + ) + return found_credentials + + async def validate(self) -> bool: + """Validate the custom endpoint by calling GET {base_url}.""" + headers = self._build_headers() + url = self.base_endpoint_url + + self._check_ssrf(url) + self.logger.info("[Custom] Validating endpoint") + + try: + async with httpx.AsyncClient(timeout=30.0, follow_redirects=False) as client: + response = await client.get(url, headers=headers) + response.raise_for_status() + + self.logger.info("[Custom] Endpoint validated successfully") + return True + + except httpx.HTTPStatusError as e: + status = e.response.status_code + if status in (401, 403): + raise AuthProviderAuthError( + f"Custom endpoint validation failed: 
{status}", + provider_name="custom", + ) from e + if status >= 500: + raise AuthProviderTemporaryError( + f"Custom endpoint validation failed: {status}", + provider_name="custom", + status_code=status, + ) from e + raise AuthProviderConfigError( + f"Custom endpoint validation failed: HTTP {status}", + provider_name="custom", + ) from e + except (httpx.ConnectError, httpx.TimeoutException) as e: + raise AuthProviderTemporaryError( + f"Custom endpoint unreachable during validation: {e}", + provider_name="custom", + ) from e diff --git a/backend/airweave/domains/auth_provider/providers/pipedream.py b/backend/airweave/domains/auth_provider/providers/pipedream.py index edccb5de8..94f1bedb7 100644 --- a/backend/airweave/domains/auth_provider/providers/pipedream.py +++ b/backend/airweave/domains/auth_provider/providers/pipedream.py @@ -2,6 +2,7 @@ import time from typing import Any, Dict, List, Optional, Set +from uuid import UUID import httpx @@ -278,6 +279,7 @@ async def get_creds_for_source( source_short_name: str, source_auth_config_fields: List[str], optional_fields: Optional[Set[str]] = None, + source_connection_id: Optional[UUID] = None, ) -> Dict[str, Any]: """Get credentials for a source from Pipedream. 
diff --git a/backend/airweave/domains/auth_provider/registry.py b/backend/airweave/domains/auth_provider/registry.py index 4608f3def..800d71509 100644 --- a/backend/airweave/domains/auth_provider/registry.py +++ b/backend/airweave/domains/auth_provider/registry.py @@ -93,6 +93,7 @@ def _build_entry(provider_cls: type) -> AuthProviderRegistryEntry: field_name_mapping: dict[str, str] = getattr(provider_cls, "FIELD_NAME_MAPPING", {}) slug_name_mapping: dict[str, str] = getattr(provider_cls, "SLUG_NAME_MAPPING", {}) settings_url: str = getattr(provider_cls, "SETTINGS_URL", "") + feature_flag: str | None = getattr(provider_cls, "feature_flag", None) # ------------------------------------------------------------------ # Precompute fields @@ -120,4 +121,6 @@ def _build_entry(provider_cls: type) -> AuthProviderRegistryEntry: slug_name_mapping=slug_name_mapping, # Settings URL settings_url=settings_url, + # Feature flag + feature_flag=feature_flag, ) diff --git a/backend/airweave/domains/auth_provider/service.py b/backend/airweave/domains/auth_provider/service.py index b44cb9994..f8ff7d1cd 100644 --- a/backend/airweave/domains/auth_provider/service.py +++ b/backend/airweave/domains/auth_provider/service.py @@ -10,7 +10,7 @@ from airweave.core import credentials from airweave.core.datetime_utils import utc_now_naive from airweave.core.exceptions import InvalidInputError, InvalidStateError, NotFoundException -from airweave.core.shared_models import ConnectionStatus, IntegrationType +from airweave.core.shared_models import ConnectionStatus, FeatureFlag, IntegrationType from airweave.db.unit_of_work import UnitOfWork from airweave.domains.auth_provider.protocols import ( AuthProviderRegistryProtocol, @@ -52,8 +52,17 @@ async def list_connections( return result async def list_metadata(self, *, ctx: ApiContext) -> list[AuthProviderMetadata]: - """List auth provider metadata from registry.""" - return [self._entry_to_metadata(entry) for entry in self._registry.list_all()] + 
"""List auth provider metadata from registry. + + Entries gated by a feature flag are excluded unless the organization + has that flag enabled. + """ + enabled_features = ctx.organization.enabled_features or [] + return [ + self._entry_to_metadata(entry) + for entry in self._registry.list_all() + if not self._is_hidden_by_feature_flag(entry, enabled_features) + ] async def get_metadata(self, *, short_name: str, ctx: ApiContext) -> AuthProviderMetadata: """Get auth provider metadata by short name from registry.""" @@ -388,6 +397,19 @@ async def _to_schema( masked_client_id=masked_client_id, ) + @staticmethod + def _is_hidden_by_feature_flag( + entry: AuthProviderRegistryEntry, enabled_features: list[FeatureFlag] + ) -> bool: + """Return True if the entry requires a feature flag the org doesn't have.""" + if not entry.feature_flag: + return False + try: + required = FeatureFlag(entry.feature_flag) + return required not in enabled_features + except ValueError: + return False + @staticmethod def _entry_to_metadata(entry: AuthProviderRegistryEntry) -> AuthProviderMetadata: """Map a registry entry to public metadata.""" diff --git a/backend/airweave/domains/auth_provider/tests/test_custom.py b/backend/airweave/domains/auth_provider/tests/test_custom.py new file mode 100644 index 000000000..84e978115 --- /dev/null +++ b/backend/airweave/domains/auth_provider/tests/test_custom.py @@ -0,0 +1,355 @@ +"""Tests for CustomAuthProvider.""" + +from unittest.mock import AsyncMock, patch +from uuid import UUID + +import httpx +import pytest + +from airweave.domains.auth_provider.exceptions import ( + AuthProviderAuthError, + AuthProviderConfigError, + AuthProviderMissingFieldsError, + AuthProviderRateLimitError, + AuthProviderTemporaryError, +) +from airweave.domains.auth_provider.providers.custom import CustomAuthProvider + +TEST_SC_ID = UUID("d035439c-dc7d-4813-a207-c68e548cfe51") + + +@pytest.fixture +async def provider(): + """Create a Custom provider.""" + return await 
CustomAuthProvider.create( + credentials={ + "base_endpoint_url": "https://api.example.com/tokens", + "api_key": "my-secret-key", + } + ) + + +class TestCreate: + """Tests for CustomAuthProvider.create().""" + + @pytest.mark.unit + async def test_create(self, provider): + assert provider.base_endpoint_url == "https://api.example.com/tokens" + assert provider.api_key == "my-secret-key" + + @pytest.mark.unit + async def test_create_strips_trailing_slash(self): + p = await CustomAuthProvider.create( + credentials={ + "base_endpoint_url": "https://api.example.com/tokens/", + "api_key": "key", + } + ) + assert p.base_endpoint_url == "https://api.example.com/tokens" + + +class TestBuildHeaders: + """Tests for _build_headers().""" + + @pytest.mark.unit + async def test_headers(self, provider): + headers = provider._build_headers() + assert headers["Accept"] == "application/json" + assert headers["X-API-Key"] == "my-secret-key" + + +class TestGetCredsForSource: + """Tests for get_creds_for_source().""" + + @pytest.mark.unit + async def test_requires_source_connection_id(self, provider): + with pytest.raises(AuthProviderConfigError, match="source_connection_id"): + await provider.get_creds_for_source("slack", ["access_token"]) + + @pytest.mark.unit + async def test_success(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "eyJ-gdrive-token", "refresh_token": "rt-123"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "google_drive", + ["access_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"access_token": "eyJ-gdrive-token"} + + @pytest.mark.unit + async def test_maps_access_token_to_personal_access_token(self, provider): + """Customer returns access_token, provider maps to personal_access_token for GitHub.""" + mock_response = httpx.Response( + 
200, + json={"access_token": "ghp_test123"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "github", + ["personal_access_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"personal_access_token": "ghp_test123"} + + @pytest.mark.unit + async def test_maps_access_token_to_api_token(self, provider): + """Customer returns access_token, provider maps to api_token for Document360.""" + mock_response = httpx.Response( + 200, + json={"access_token": "doc360_token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + "document360", + ["api_token"], + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"api_token": "doc360_token"} + + @pytest.mark.unit + async def test_calls_correct_url(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_get: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + mock_get.assert_called_once() + call_args = mock_get.call_args + assert call_args.args[0] == f"https://api.example.com/tokens/{TEST_SC_ID}" + + @pytest.mark.unit + async def test_optional_fields_not_required(self, provider): + mock_response = httpx.Response( + 200, + json={"access_token": "token"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + creds = await provider.get_creds_for_source( + 
"google_drive", + ["access_token", "refresh_token"], + optional_fields={"refresh_token"}, + source_connection_id=TEST_SC_ID, + ) + + assert creds == {"access_token": "token"} + + @pytest.mark.unit + async def test_error_401(self, provider): + mock_response = httpx.Response( + 401, + json={"error": "unauthorized"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderAuthError, match="401"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_429(self, provider): + mock_response = httpx.Response( + 429, + json={"error": "rate limited"}, + headers={"retry-after": "60"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderRateLimitError) as exc_info: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + assert exc_info.value.retry_after == 60.0 + + @pytest.mark.unit + async def test_error_500(self, provider): + mock_response = httpx.Response( + 500, + json={"error": "internal"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderTemporaryError, match="500"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_timeout(self, provider): + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + side_effect=httpx.TimeoutException("timed out"), + ): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await 
provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_error_missing_fields(self, provider): + mock_response = httpx.Response( + 200, + json={"some_other_field": "value"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderMissingFieldsError) as exc_info: + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + assert "access_token" in exc_info.value.missing_fields + + @pytest.mark.unit + async def test_error_404(self, provider): + mock_response = httpx.Response( + 404, + json={"error": "not found"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderMissingFieldsError, match="404"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + @pytest.mark.unit + async def test_ssrf_blocked(self, provider): + provider.base_endpoint_url = "http://169.254.169.254/latest/meta-data" + + with pytest.raises(AuthProviderConfigError, match="SSRF"): + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + +class TestValidate: + """Tests for validate().""" + + @pytest.mark.unit + async def test_validate_success(self, provider): + mock_response = httpx.Response( + 200, + json={"status": "ok"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + result = await provider.validate() + + assert result is True + + @pytest.mark.unit + async def test_validate_auth_error(self, provider): + mock_response = httpx.Response( + 401, + 
json={"error": "unauthorized"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderAuthError): + await provider.validate() + + @pytest.mark.unit + async def test_validate_server_error(self, provider): + mock_response = httpx.Response( + 503, + json={"error": "unavailable"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + + with patch("httpx.AsyncClient.get", new_callable=AsyncMock, return_value=mock_response): + with pytest.raises(AuthProviderTemporaryError): + await provider.validate() + + @pytest.mark.unit + async def test_validate_timeout(self, provider): + with patch( + "httpx.AsyncClient.get", + new_callable=AsyncMock, + side_effect=httpx.TimeoutException("timed out"), + ): + with pytest.raises(AuthProviderTemporaryError, match="unreachable"): + await provider.validate() + + @pytest.mark.unit + async def test_validate_ssrf_blocked(self, provider): + provider.base_endpoint_url = "http://169.254.169.254/latest/meta-data" + + with pytest.raises(AuthProviderConfigError, match="SSRF"): + await provider.validate() + + +class TestFollowRedirectsDisabled: + """Verify httpx.AsyncClient is created with follow_redirects=False.""" + + @pytest.mark.unit + async def test_get_creds_no_follow_redirects(self, provider): + with patch( + "airweave.domains.auth_provider.providers.custom.httpx.AsyncClient" + ) as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get.return_value = httpx.Response( + 200, + json={"access_token": "tok"}, + request=httpx.Request("GET", f"https://api.example.com/tokens/{TEST_SC_ID}"), + ) + mock_client_cls.return_value = mock_client + + await provider.get_creds_for_source( + "slack", ["access_token"], source_connection_id=TEST_SC_ID, + ) + + 
mock_client_cls.assert_called_once_with(timeout=30.0, follow_redirects=False) + + @pytest.mark.unit + async def test_validate_no_follow_redirects(self, provider): + with patch( + "airweave.domains.auth_provider.providers.custom.httpx.AsyncClient" + ) as mock_client_cls: + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get.return_value = httpx.Response( + 200, + json={"status": "ok"}, + request=httpx.Request("GET", "https://api.example.com/tokens"), + ) + mock_client_cls.return_value = mock_client + + await provider.validate() + + mock_client_cls.assert_called_once_with(timeout=30.0, follow_redirects=False) diff --git a/backend/airweave/domains/auth_provider/tests/test_service.py b/backend/airweave/domains/auth_provider/tests/test_service.py index afac3cfa3..33c6be419 100644 --- a/backend/airweave/domains/auth_provider/tests/test_service.py +++ b/backend/airweave/domains/auth_provider/tests/test_service.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from airweave.core.exceptions import InvalidInputError, InvalidStateError, NotFoundException -from airweave.core.shared_models import IntegrationType +from airweave.core.shared_models import FeatureFlag, IntegrationType from airweave.domains.auth_provider.service import AuthProviderService from airweave.domains.auth_provider.types import AuthProviderRegistryEntry from airweave.platform.configs._base import Fields @@ -403,3 +403,73 @@ async def test_to_schema_include_and_exclude_masked(): no_mask = await service._to_schema("db", conn, _ctx(), include_masked_client_id=False) assert no_mask.masked_client_id is None + + +# --------------------------------------------------------------------------- +# Feature flag filtering +# --------------------------------------------------------------------------- + +_UNFLAGGED = _entry(short_name="composio") +_FLAGGED = AuthProviderRegistryEntry( + 
**{**_entry(short_name="custom").model_dump(), "name": "Custom", "feature_flag": "custom_auth_provider"} +) + + +def _ctx_with_features(features: list[FeatureFlag] | None = None): + """Build a context with organization.enabled_features.""" + org = SimpleNamespace(enabled_features=features or []) + logger = SimpleNamespace(info=lambda *a, **k: None, error=lambda *a, **k: None) + return SimpleNamespace( + logger=logger, + organization=org, + has_user_context=True, + tracking_email="owner@airweave.ai", + ) + + +class TestIsHiddenByFeatureFlag: + """Tests for _is_hidden_by_feature_flag.""" + + def test_no_flag_always_visible(self): + assert AuthProviderService._is_hidden_by_feature_flag(_UNFLAGGED, []) is False + + def test_flag_missing_from_org(self): + assert AuthProviderService._is_hidden_by_feature_flag(_FLAGGED, []) is True + + def test_flag_present_in_org(self): + assert ( + AuthProviderService._is_hidden_by_feature_flag( + _FLAGGED, [FeatureFlag.CUSTOM_AUTH_PROVIDER] + ) + is False + ) + + def test_unknown_flag_fails_open(self): + entry = AuthProviderRegistryEntry( + **{**_UNFLAGGED.model_dump(), "feature_flag": "nonexistent_flag"} + ) + assert AuthProviderService._is_hidden_by_feature_flag(entry, []) is False + + +@pytest.mark.asyncio +async def test_list_metadata_hides_flagged_provider(): + registry = SimpleNamespace(list_all=lambda: [_UNFLAGGED, _FLAGGED]) + service = AuthProviderService(registry, connection_repo=None, credential_repo=None) + + ctx = _ctx_with_features() + result = await service.list_metadata(ctx=ctx) + names = [m.short_name for m in result] + assert "composio" in names + assert "custom" not in names + + +@pytest.mark.asyncio +async def test_list_metadata_shows_flagged_provider_with_flag(): + registry = SimpleNamespace(list_all=lambda: [_UNFLAGGED, _FLAGGED]) + service = AuthProviderService(registry, connection_repo=None, credential_repo=None) + + ctx = _ctx_with_features([FeatureFlag.CUSTOM_AUTH_PROVIDER]) + result = await 
service.list_metadata(ctx=ctx) + names = [m.short_name for m in result] + assert "composio" in names + assert "custom" in names diff --git a/backend/airweave/domains/auth_provider/types.py b/backend/airweave/domains/auth_provider/types.py index 4704756ed..19892ffbd 100644 --- a/backend/airweave/domains/auth_provider/types.py +++ b/backend/airweave/domains/auth_provider/types.py @@ -24,6 +24,9 @@ class AuthProviderRegistryEntry(BaseRegistryEntry): # Settings dashboard URL settings_url: str = "" + # Feature flag gating + feature_flag: str | None = None + class AuthProviderMetadata(BaseRegistryEntry): """Public auth provider metadata returned by API endpoints.""" diff --git a/backend/airweave/domains/collections/service.py b/backend/airweave/domains/collections/service.py index c77a6b3b6..86f77712a 100644 --- a/backend/airweave/domains/collections/service.py +++ b/backend/airweave/domains/collections/service.py @@ -21,7 +21,7 @@ ) from airweave.domains.embedders.protocols import DenseEmbedderRegistryProtocol from airweave.domains.source_connections.protocols import SourceConnectionRepositoryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.collection import Collection from airweave.schemas.collection import SourceConnectionSummary @@ -33,7 +33,7 @@ def __init__( self, collection_repo: CollectionRepositoryProtocol, sc_repo: SourceConnectionRepositoryProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, + sync_service: SyncServiceProtocol, event_bus: EventBus, settings: Settings, deployment_metadata_repo: VectorDbDeploymentMetadataRepositoryProtocol, @@ -42,7 +42,7 @@ def __init__( """Initialize with injected dependencies.""" self._collection_repo = collection_repo self._sc_repo = sc_repo - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service self._event_bus = event_bus self._settings = settings self._deployment_metadata_repo 
= deployment_metadata_repo @@ -179,34 +179,31 @@ async def delete( if db_obj is None: raise CollectionNotFoundError(readable_id) - collection_id = db_obj.id - organization_id = ctx.organization.id - # Snapshot while session is fresh (teardown expires all objects via db.expire_all) result = self._to_response(db_obj) # Collect sync IDs before CASCADE removes them sync_ids = await self._sc_repo.get_sync_ids_for_collection( - db, organization_id=organization_id, readable_collection_id=result.readable_id + db, organization_id=ctx.organization.id, readable_collection_id=result.readable_id ) - # Cancel running workflows and wait for workers to stop - await self._sync_lifecycle.teardown_syncs_for_collection( - db, - sync_ids=sync_ids, - collection_id=collection_id, - organization_id=organization_id, - ctx=ctx, - ) + for sid in sync_ids: + await self._sync_service.delete( + db, + sync_id=sid, + collection_id=result.id, + organization_id=ctx.organization.id, + ctx=ctx, + ) # CASCADE-delete the collection and all child objects - await self._collection_repo.remove(db, id=collection_id, ctx=ctx) + await self._collection_repo.remove(db, id=result.id, ctx=ctx) # Publish event try: await self._event_bus.publish( CollectionLifecycleEvent.deleted( - organization_id=organization_id, + organization_id=ctx.organization.id, collection_id=result.id, collection_name=result.name, collection_readable_id=result.readable_id, diff --git a/backend/airweave/domains/collections/tests/test_service.py b/backend/airweave/domains/collections/tests/test_service.py index 1b16d38aa..7d019faf4 100644 --- a/backend/airweave/domains/collections/tests/test_service.py +++ b/backend/airweave/domains/collections/tests/test_service.py @@ -28,7 +28,7 @@ from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService from 
airweave.models.collection import Collection from airweave.schemas.organization import Organization @@ -132,7 +132,7 @@ def _fake_dense_registry() -> FakeDenseEmbedderRegistry: def _build_service( collection_repo=None, sc_repo=None, - sync_lifecycle=None, + sync_service=None, event_bus=None, settings=None, deployment_metadata_repo=None, @@ -141,7 +141,7 @@ def _build_service( return CollectionService( collection_repo=collection_repo or FakeCollectionRepository(), sc_repo=sc_repo or FakeSourceConnectionRepository(), - sync_lifecycle=sync_lifecycle or FakeSyncLifecycleService(), + sync_service=sync_service or FakeSyncService(), event_bus=event_bus or _FakeEventBus(), settings=settings or _fake_settings(), deployment_metadata_repo=deployment_metadata_repo @@ -388,7 +388,7 @@ async def test_delete_full_flow(): """delete() gathers sync IDs, calls teardown, cascade-deletes, publishes event.""" repo = FakeCollectionRepository() sc_repo = FakeSourceConnectionRepository() - sync_lifecycle = FakeSyncLifecycleService() + sync_service = FakeSyncService() event_bus = _FakeEventBus() col = _collection() @@ -402,7 +402,7 @@ async def test_delete_full_flow(): svc = _build_service( collection_repo=repo, sc_repo=sc_repo, - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, ) @@ -410,13 +410,13 @@ async def test_delete_full_flow(): assert result is not None - # Verify teardown was called with correct args - teardown_calls = [c for c in sync_lifecycle._calls if c[0] == "teardown_syncs_for_collection"] - assert len(teardown_calls) == 1 - _, _, sync_ids, coll_id, org_id, _ = teardown_calls[0] - assert set(sync_ids) == {sync_id_1, sync_id_2} - assert coll_id == COLLECTION_ID - assert org_id == ORG_ID + # Verify one delete per sync ID: ("delete", sync_id, collection_id, organization_id) + delete_calls = [c for c in sync_service._calls if c[0] == "delete"] + assert len(delete_calls) == 2 + assert {c[1] for c in delete_calls} == {sync_id_1, sync_id_2} + for c 
in delete_calls: + assert c[2] == COLLECTION_ID # collection_id + assert c[3] == ORG_ID # organization_id # Verify cascade delete was called remove_calls = [c for c in repo._calls if c[0] == "remove"] @@ -438,10 +438,10 @@ async def test_delete_not_found(): @pytest.mark.asyncio async def test_delete_no_syncs(): - """delete() works when collection has no syncs — teardown called with empty list.""" + """delete() works when collection has no syncs — no sync delete calls (empty gather).""" repo = FakeCollectionRepository() sc_repo = FakeSourceConnectionRepository() - sync_lifecycle = FakeSyncLifecycleService() + sync_service = FakeSyncService() event_bus = _FakeEventBus() col = _collection() @@ -452,7 +452,7 @@ async def test_delete_no_syncs(): svc = _build_service( collection_repo=repo, sc_repo=sc_repo, - sync_lifecycle=sync_lifecycle, + sync_service=sync_service, event_bus=event_bus, ) @@ -460,7 +460,5 @@ async def test_delete_no_syncs(): assert result is not None - teardown_calls = [c for c in sync_lifecycle._calls if c[0] == "teardown_syncs_for_collection"] - assert len(teardown_calls) == 1 - _, _, sync_ids, _, _, _ = teardown_calls[0] - assert sync_ids == [] + delete_calls = [c for c in sync_service._calls if c[0] == "delete"] + assert len(delete_calls) == 0 diff --git a/backend/airweave/domains/connect/tests/conftest.py b/backend/airweave/domains/connect/tests/conftest.py index 2237507db..342f7c414 100644 --- a/backend/airweave/domains/connect/tests/conftest.py +++ b/backend/airweave/domains/connect/tests/conftest.py @@ -63,8 +63,8 @@ def org_repo(): @pytest.fixture -def sc_service(fake_sync_lifecycle): - return FakeSourceConnectionService(sync_lifecycle=fake_sync_lifecycle) +def sc_service(fake_sync_service): + return FakeSourceConnectionService(sync_service=fake_sync_service) @pytest.fixture diff --git a/backend/airweave/domains/oauth/callback_service.py b/backend/airweave/domains/oauth/callback_service.py index 8b6ee0650..a682eda78 100644 --- 
a/backend/airweave/domains/oauth/callback_service.py +++ b/backend/airweave/domains/oauth/callback_service.py @@ -22,7 +22,7 @@ from airweave.core.logging import logger from airweave.core.protocols.encryption import CredentialEncryptor from airweave.core.protocols.event_bus import EventBus -from airweave.core.shared_models import AuthMethod, ConnectionStatus, SyncJobStatus, SyncStatus +from airweave.core.shared_models import AuthMethod, ConnectionStatus, SyncJobStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -45,12 +45,7 @@ ) from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, - SyncRepositoryProtocol, - SyncStateMachineProtocol, -) +from airweave.domains.syncs.protocols import SyncRepositoryProtocol, SyncServiceProtocol from airweave.domains.syncs.types import InvalidSyncTransitionError, OptimisticLockError from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol from airweave.models.collection import Collection @@ -87,10 +82,8 @@ def __init__( response_builder: ResponseBuilderProtocol, source_registry: SourceRegistryProtocol, source_lifecycle: SourceLifecycleServiceProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: SyncServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, - sync_state_machine: SyncStateMachineProtocol, event_bus: EventBus, organization_repo: OrganizationRepositoryProtocol, sc_repo: SourceConnectionRepositoryProtocol, @@ -107,10 +100,8 @@ def __init__( self._response_builder = response_builder self._source_registry = source_registry self._source_lifecycle = 
source_lifecycle - self._sync_lifecycle = sync_lifecycle - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._temporal_workflow_service = temporal_workflow_service - self._sync_state_machine = sync_state_machine self._event_bus = event_bus self._organization_repo = organization_repo self._sc_repo = sc_repo @@ -203,6 +194,7 @@ async def complete_oauth2_callback( await self._validate_oauth2_token_or_raise( source_entry=source_entry, access_token=token_response.access_token, + config=source_conn_shell.config_fields, ctx=ctx, ) @@ -529,11 +521,9 @@ async def _complete_connection_common( # noqa: C901 if raw_cron: schedule_config = ScheduleConfig(cron=raw_cron) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) - sync_result = await self._sync_lifecycle.provision_sync( + sync_result = await self._sync_service.create( uow.session, name=payload.get("name") or source_entry.name, source_connection_id=connection.id, @@ -575,10 +565,9 @@ async def _complete_connection_common( # noqa: C901 if source_conn.sync_id: try: - await self._sync_state_machine.transition( - sync_id=source_conn.sync_id, - target=SyncStatus.ACTIVE, - ctx=ctx, + await self._sync_service.resume( + source_conn.sync_id, + ctx, reason="OAuth completed", ) except (InvalidSyncTransitionError, OptimisticLockError, ValueError): @@ -591,6 +580,7 @@ async def _validate_oauth2_token_or_raise( *, source_entry: SourceRegistryEntry | None, access_token: str, + config: dict | None = None, ctx: ApiContext, ) -> None: """Validate OAuth2 token using source lifecycle service; fail callback if invalid.""" @@ -601,6 +591,7 @@ async def _validate_oauth2_token_or_raise( await self._source_lifecycle.validate( short_name=source_entry.short_name, credentials=access_token, + config=config, ) except (SourceNotFoundError, SourceError) as e: raise 
http_exception_for_credential_validation( diff --git a/backend/airweave/domains/oauth/oauth2_service.py b/backend/airweave/domains/oauth/oauth2_service.py index e53fa4e26..b5c2188ef 100644 --- a/backend/airweave/domains/oauth/oauth2_service.py +++ b/backend/airweave/domains/oauth/oauth2_service.py @@ -475,6 +475,82 @@ async def refresh_and_persist( expires_in=response.expires_in, ) + async def exchange_token_for_scope( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + scope: str, + ) -> str: + """Exchange refresh token for an access token with a different scope. + + Uses the existing refresh token but requests a different resource scope + (e.g., SharePoint REST API scope instead of Graph scope). + Does NOT persist the rotated refresh token. + + Returns the access token string for the requested scope. + """ + connection = await self.conn_repo.get(db=db, id=connection_id, ctx=ctx) + if not connection or not connection.integration_credential_id: + raise OAuthRefreshCredentialMissingError( + f"Connection {connection_id} not found or has no credential", + integration_short_name=integration_short_name, + ) + + credential = await self.cred_repo.get( + db=db, id=connection.integration_credential_id, ctx=ctx + ) + if not credential: + raise OAuthRefreshCredentialMissingError( + "Integration credential not found", + integration_short_name=integration_short_name, + ) + + decrypted = self.encryptor.decrypt(credential.encrypted_credentials) + refresh_token = await self._get_refresh_token(ctx.logger, decrypted) + + integration_config = await self._get_integration_config(ctx.logger, integration_short_name) + + client_id, client_secret = await self._get_client_credentials( + integration_config, None, decrypted + ) + + # Build request with explicit scope (unlike normal refresh which skips scope) + headers = {"Content-Type": integration_config.content_type} + payload = { + "grant_type": "refresh_token", + "refresh_token": 
refresh_token, + "scope": scope, + } + + if integration_config.client_credential_location == "header": + encoded = self._encode_client_credentials(client_id, client_secret) + headers["Authorization"] = f"Basic {encoded}" + else: + payload["client_id"] = client_id + payload["client_secret"] = client_secret + + ctx.logger.info( + f"Exchanging token for scope {scope} (integration={integration_short_name})" + ) + + response = await self._make_token_request( + ctx.logger, + integration_config.backend_url, + headers, + payload, + integration_short_name=integration_short_name, + ) + + # Parse response but do NOT persist the refresh token + token_response = OAuth2TokenResponse(**response.json()) + ctx.logger.info( + f"Successfully exchanged token for scope {scope} " + f"(expires_in={token_response.expires_in})" + ) + return str(token_response.access_token) + # ------------------------------------------------------------------ # Private helpers # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/oauth/protocols.py b/backend/airweave/domains/oauth/protocols.py index ba55594e1..4faa4cf05 100644 --- a/backend/airweave/domains/oauth/protocols.py +++ b/backend/airweave/domains/oauth/protocols.py @@ -143,6 +143,24 @@ async def refresh_and_persist( """ ... + async def exchange_token_for_scope( + self, + db: AsyncSession, + integration_short_name: str, + connection_id: UUID, + ctx: ApiContext, + scope: str, + ) -> str: + """Exchange refresh token for an access token with a different scope. + + Uses the existing refresh token but requests a different resource scope. + Does NOT persist the rotated refresh token (the response token is + scoped to the new resource and should not replace the original). + + Returns the access token string for the requested scope. + """ + ... 
+ # --------------------------------------------------------------------------- # Init session + redirect session repositories diff --git a/backend/airweave/domains/oauth/tests/test_callback_service.py b/backend/airweave/domains/oauth/tests/test_callback_service.py index 82bd3ad9a..a2b1a762c 100644 --- a/backend/airweave/domains/oauth/tests/test_callback_service.py +++ b/backend/airweave/domains/oauth/tests/test_callback_service.py @@ -120,10 +120,8 @@ def _service( response_builder=None, source_registry=None, source_lifecycle=None, - sync_lifecycle=None, - sync_record_service=None, + sync_service=None, temporal_workflow_service=None, - sync_state_machine=None, event_bus=None, ) -> OAuthCallbackService: return OAuthCallbackService( @@ -132,10 +130,8 @@ def _service( response_builder=response_builder or AsyncMock(), source_registry=source_registry or MagicMock(), source_lifecycle=source_lifecycle or AsyncMock(), - sync_lifecycle=sync_lifecycle or AsyncMock(), - sync_record_service=sync_record_service or AsyncMock(), + sync_service=sync_service or AsyncMock(), temporal_workflow_service=temporal_workflow_service or AsyncMock(), - sync_state_machine=sync_state_machine or AsyncMock(), event_bus=event_bus or AsyncMock(), organization_repo=organization_repo or FakeOrganizationRepository(), sc_repo=sc_repo or FakeSourceConnectionRepository(), @@ -869,8 +865,8 @@ async def test_federated_source_skips_sync_provisioning(self): svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=True)) ) - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) - svc._sync_lifecycle.provision_sync = AsyncMock() + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.create = AsyncMock() from airweave.domains.oauth import callback_service as callback_module @@ -916,7 +912,7 @@ async def commit(self): finally: monkeypatch.undo() - 
svc._sync_lifecycle.provision_sync.assert_not_awaited() + svc._sync_service.create.assert_not_awaited() async def test_claim_token_session_skips_mark_completed(self): svc = _service() @@ -931,9 +927,9 @@ async def test_claim_token_session_skips_mark_completed(self): return_value=SimpleNamespace(id=sc_id, connection_id=conn_id, sync_id=None) ) svc._init_session_repo.mark_completed = AsyncMock() - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) sync_id = uuid4() - svc._sync_lifecycle.provision_sync = AsyncMock( + svc._sync_service.create = AsyncMock( return_value=SimpleNamespace(sync_id=sync_id) ) @@ -1000,9 +996,9 @@ async def test_non_federated_source_provisions_sync_with_cron_schedule(self): svc._source_registry.get = MagicMock( return_value=SimpleNamespace(source_class_ref=SimpleNamespace(federated_search=False)) ) - svc._sync_record_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) + svc._sync_service.resolve_destination_ids = AsyncMock(return_value=[uuid4()]) sync_id = uuid4() - svc._sync_lifecycle.provision_sync = AsyncMock( + svc._sync_service.create = AsyncMock( return_value=SimpleNamespace(sync_id=sync_id) ) @@ -1050,7 +1046,7 @@ async def commit(self): finally: monkeypatch.undo() - svc._sync_lifecycle.provision_sync.assert_awaited_once() + svc._sync_service.create.assert_awaited_once() svc._init_session_repo.mark_completed.assert_awaited_once() @@ -2219,7 +2215,7 @@ async def test_complete_connection_common_activates_sync(self): """_complete_connection_common transitions sync to ACTIVE after commit.""" from unittest.mock import patch - sync_state_machine = AsyncMock() + sync_service = AsyncMock() sc_repo = FakeSourceConnectionRepository() shell = _source_conn_shell() shell.connection_id = uuid4() @@ -2227,7 +2223,7 @@ async def test_complete_connection_common_activates_sync(self): svc = _service( sc_repo=sc_repo, - 
sync_state_machine=sync_state_machine, + sync_service=sync_service, ) svc._validate_config = MagicMock(return_value={}) collection = MagicMock() @@ -2267,4 +2263,4 @@ async def test_complete_connection_common_activates_sync(self): has_claim_token=True, ) - sync_state_machine.transition.assert_awaited_once() + sync_service.resume.assert_awaited_once() diff --git a/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py b/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py index 1a4a4312e..5c3c63e41 100644 --- a/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py +++ b/backend/airweave/domains/search/adapters/vector_db/fakes/vector_db.py @@ -61,6 +61,7 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: list[str] | None = None, ) -> CompiledQuery: """Return a fake compiled query, or raise seeded error.""" self._calls.append(("compile_query", plan, embeddings, collection_id)) diff --git a/backend/airweave/domains/search/adapters/vector_db/protocol.py b/backend/airweave/domains/search/adapters/vector_db/protocol.py index 2cefaabea..aa5f0dac2 100644 --- a/backend/airweave/domains/search/adapters/vector_db/protocol.py +++ b/backend/airweave/domains/search/adapters/vector_db/protocol.py @@ -1,6 +1,6 @@ """Vector database protocol for the search module.""" -from typing import Protocol +from typing import Optional, Protocol from airweave.domains.search.types.embeddings import QueryEmbeddings from airweave.domains.search.types.filters import FilterGroup @@ -23,6 +23,7 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> CompiledQuery: """Compile plan and embeddings into a DB-specific query. @@ -30,6 +31,10 @@ async def compile_query( plan: Search plan with queries, filters, strategy, pagination. embeddings: Dense and sparse embeddings for the queries. 
collection_id: Collection readable ID for tenant filtering. + acl_principals: Resolved user principals for access control filtering. + None = no AC sources in collection (skip filtering). + [] = user has no principals (only public entities visible). + ["user:x", "group:y"] = match these principals. Returns: CompiledQuery with raw (full) and display (no embeddings) versions. diff --git a/backend/airweave/domains/search/adapters/vector_db/vespa_client.py b/backend/airweave/domains/search/adapters/vector_db/vespa_client.py index 9a864fafd..9dfcb2704 100644 --- a/backend/airweave/domains/search/adapters/vector_db/vespa_client.py +++ b/backend/airweave/domains/search/adapters/vector_db/vespa_client.py @@ -87,9 +87,10 @@ async def compile_query( plan: SearchPlan, embeddings: QueryEmbeddings, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> CompiledQuery: """Compile plan and embeddings into Vespa query.""" - yql = self._build_yql(plan, collection_id) + yql = self._build_yql(plan, collection_id, acl_principals=acl_principals) params = self._build_params(plan, embeddings) raw_query = {"yql": yql, "params": params} @@ -236,7 +237,12 @@ async def close(self) -> None: # YQL Building # ========================================================================= - def _build_yql(self, plan: SearchPlan, collection_id: str) -> str: + def _build_yql( + self, + plan: SearchPlan, + collection_id: str, + acl_principals: Optional[list[str]] = None, + ) -> str: """Build the complete YQL query string.""" num_embeddings = self._count_dense_embeddings(plan) retrieval_clause = self._build_retrieval_clause(plan.retrieval_strategy, num_embeddings) @@ -250,11 +256,31 @@ def _build_yql(self, plan: SearchPlan, collection_id: str) -> str: if filter_yql: where_parts.append(f"({filter_yql})") + acl_yql = self._build_acl_clause(acl_principals) + if acl_yql: + where_parts.append(f"({acl_yql})") + all_schemas = ", ".join(ALL_VESPA_SCHEMAS) yql = f"select * from sources {all_schemas} 
where {' AND '.join(where_parts)}" return yql + def _build_acl_clause(self, acl_principals: Optional[list[str]]) -> Optional[str]: + """Build access control YQL clause from resolved principals. + + Returns None if acl_principals is None (no AC sources → skip filtering). + Returns a clause that matches public entities or entities with matching viewers. + """ + if acl_principals is None: + return None + + clauses = ["access_is_public = true"] + for principal in acl_principals: + escaped = principal.replace("\\", "\\\\").replace("'", "\\'") + clauses.append(f"access_viewers contains '{escaped}'") + + return " OR ".join(clauses) + def _build_retrieval_clause( self, strategy: RetrievalStrategy, diff --git a/backend/airweave/domains/search/agentic/agent.py b/backend/airweave/domains/search/agentic/agent.py index 8097ebed1..75b0da96a 100644 --- a/backend/airweave/domains/search/agentic/agent.py +++ b/backend/airweave/domains/search/agentic/agent.py @@ -128,11 +128,13 @@ async def run( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Run the agent loop. Emits events throughout. 
Returns collected results.""" start_time = time.monotonic() state = AgentState() diag = _DiagnosticsAccumulator() + self._user_principal_override = user_principal_override ctx.logger.info( f"Agentic search started collection={readable_id} query={request.query!r} " f"thinking={request.thinking}" @@ -208,7 +210,14 @@ async def _run( # noqa: C901 — agent loop orchestration is inherently complex thinking_enabled = request.thinking # Construct per-request tools - dispatcher = self._build_dispatcher(collection_id, user_filter, db, ctx, readable_id) + dispatcher = self._build_dispatcher( + collection_id, + user_filter, + db, + ctx, + readable_id, + user_principal=self._user_principal_override, + ) context_mgr = ContextManager( tokenizer=self._tokenizer, @@ -587,6 +596,7 @@ def _build_dispatcher( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: str | None = None, ) -> ToolDispatcher: """Construct tools and dispatcher for this request.""" return ToolDispatcher( @@ -598,6 +608,7 @@ def _build_dispatcher( db=db, ctx=ctx, collection_readable_id=collection_readable_id, + user_principal=user_principal, ), ToolName.READ: ReadTool( vector_db=self._vector_db, diff --git a/backend/airweave/domains/search/agentic/service.py b/backend/airweave/domains/search/agentic/service.py index bd167cd36..7f15fd193 100644 --- a/backend/airweave/domains/search/agentic/service.py +++ b/backend/airweave/domains/search/agentic/service.py @@ -72,6 +72,7 @@ async def search( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Run agentic search and return results.""" agent = Agent( @@ -85,4 +86,6 @@ async def search( event_bus=self._event_bus, config=SearchConfig(), ) - return await agent.run(db, ctx, readable_id, request) + return await agent.run( + db, ctx, readable_id, request, user_principal_override=user_principal_override + ) diff --git 
a/backend/airweave/domains/search/agentic/tools/search.py b/backend/airweave/domains/search/agentic/tools/search.py index 3818d6311..20a6e53c1 100644 --- a/backend/airweave/domains/search/agentic/tools/search.py +++ b/backend/airweave/domains/search/agentic/tools/search.py @@ -43,6 +43,7 @@ def __init__( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: str | None = None, ) -> None: """Initialize with executor, user filter, collection ID, and request context.""" self._executor = executor @@ -51,6 +52,7 @@ def __init__( self._db = db self._ctx = ctx self._collection_readable_id = collection_readable_id + self._user_principal = user_principal async def execute( self, @@ -67,6 +69,7 @@ async def execute( db=self._db, ctx=self._ctx, collection_readable_id=self._collection_readable_id, + user_principal=self._user_principal, ) # Track new results in state diff --git a/backend/airweave/domains/search/classic/service.py b/backend/airweave/domains/search/classic/service.py index f8ed42a94..b2f98db5f 100644 --- a/backend/airweave/domains/search/classic/service.py +++ b/backend/airweave/domains/search/classic/service.py @@ -68,13 +68,16 @@ async def search( ctx: ApiContext, readable_id: str, request: ClassicSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Generate strategy via LLM, execute, optionally rerank, return results.""" start_time = time.monotonic() ctx.logger.info(f"Classic search started collection={readable_id} query={request.query!r}") try: - result = await self._execute(db, ctx, readable_id, request, start_time) + result = await self._execute( + db, ctx, readable_id, request, start_time, user_principal_override + ) duration_ms = int((time.monotonic() - start_time) * 1000) ctx.logger.info( f"Classic search completed collection={readable_id} " @@ -106,6 +109,7 @@ async def _execute( readable_id: str, request: ClassicSearchRequest, start_time: float, + user_principal_override: str | None = None, ) -> 
SearchResults: """Internal execution — resolve collection, LLM strategy, search, rerank.""" # 1. Resolve collection @@ -157,6 +161,7 @@ async def _execute( db=db, ctx=ctx, collection_readable_id=readable_id, + user_principal=user_principal_override, ) # 6. Optional rerank diff --git a/backend/airweave/domains/search/config.py b/backend/airweave/domains/search/config.py index d812862b4..224e2dad1 100644 --- a/backend/airweave/domains/search/config.py +++ b/backend/airweave/domains/search/config.py @@ -2,8 +2,70 @@ from enum import Enum -from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.adapters.llm.registry import MODEL_REGISTRY, LLMModel, LLMProvider from airweave.adapters.tokenizer.registry import TokenizerEncoding, TokenizerType +from airweave.core.config import settings + +_DEFAULT_LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), +] + +# Value → enum lookup tables built once at import time. Dict insertion order +# matches enum declaration order, which we surface in error messages. +_VALID_PROVIDERS: dict[str, LLMProvider] = {p.value: p for p in LLMProvider} +_VALID_MODELS: dict[str, LLMModel] = {m.value: m for m in LLMModel} + + +def parse_llm_fallback_chain(raw: str | None) -> list[tuple[LLMProvider, LLMModel]]: + """Parse the LLM_FALLBACK_CHAIN env var. + + Format: comma-separated ``provider:model`` pairs using the enum ``value`` + strings from ``airweave.adapters.llm.registry``. When ``raw`` is None or + empty, returns the in-code default chain. + + Raises ValueError at import time (startup) on unknown provider or model names, + listing the accepted values so deployers can fix the misconfiguration fast. 
+ """ + if not raw or not raw.strip(): + return list(_DEFAULT_LLM_FALLBACK_CHAIN) + + parsed: list[tuple[LLMProvider, LLMModel]] = [] + for entry in raw.split(","): + entry = entry.strip() + if not entry: + continue + if ":" not in entry: + raise ValueError( + f"Invalid LLM_FALLBACK_CHAIN entry {entry!r}: expected 'provider:model'." + ) + provider_raw, model_raw = entry.split(":", 1) + provider_raw = provider_raw.strip() + model_raw = model_raw.strip() + + if provider_raw not in _VALID_PROVIDERS: + raise ValueError( + f"Unknown provider {provider_raw!r} in LLM_FALLBACK_CHAIN. " + f"Accepted: {list(_VALID_PROVIDERS)}." + ) + if model_raw not in _VALID_MODELS: + raise ValueError( + f"Unknown model {model_raw!r} in LLM_FALLBACK_CHAIN. " + f"Accepted: {list(_VALID_MODELS)}." + ) + provider = _VALID_PROVIDERS[provider_raw] + model = _VALID_MODELS[model_raw] + provider_models = MODEL_REGISTRY.get(provider, {}) + if model not in provider_models: + raise ValueError( + f"Model {model_raw!r} not available for provider {provider_raw!r}. " + f"Available: {[m.value for m in provider_models]}." + ) + parsed.append((provider, model)) + + if not parsed: + return list(_DEFAULT_LLM_FALLBACK_CHAIN) + return parsed class DatabaseImpl(str, Enum): @@ -35,13 +97,14 @@ class SearchConfig: # configured) and responds successfully handles the request. Subsequent # providers are only tried when the previous one fails. # - # To change the primary model, reorder this list or swap the model for a - # provider. For example, to use GPT_OSS_120B on Cerebras instead of GLM: - # (LLMProvider.CEREBRAS, LLMModel.GPT_OSS_120B), - LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = [ - (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), - (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), - ] + # Deployers can override via the LLM_FALLBACK_CHAIN env var + # (format: "provider:model,provider:model"). Unset → use the default below. + # Evaluated once at class-definition time. 
Tests that need to vary this must + # call parse_llm_fallback_chain directly or reload the module — monkey- + # patching settings.LLM_FALLBACK_CHAIN after import has no effect here. + LLM_FALLBACK_CHAIN: list[tuple[LLMProvider, LLMModel]] = parse_llm_fallback_chain( + settings.LLM_FALLBACK_CHAIN + ) # Tokenizer # Note: Must be compatible with the chosen LLM model (validated at startup) diff --git a/backend/airweave/domains/search/executor.py b/backend/airweave/domains/search/executor.py index 05c4c8ad4..8869e81f1 100644 --- a/backend/airweave/domains/search/executor.py +++ b/backend/airweave/domains/search/executor.py @@ -9,12 +9,13 @@ import asyncio from datetime import datetime -from typing import Any +from typing import Any, Optional from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from airweave.api.context import ApiContext +from airweave.domains.access_control.protocols import AccessBrokerProtocol from airweave.domains.embedders.protocols import DenseEmbedderProtocol, SparseEmbedderProtocol from airweave.domains.search.adapters.vector_db.protocol import VectorDBProtocol from airweave.domains.search.builders.search_plan import SearchPlanBuilder @@ -67,6 +68,7 @@ def __init__( sc_repo: SourceConnectionRepositoryProtocol, source_registry: SourceRegistryProtocol, source_lifecycle: SourceLifecycleServiceProtocol, + access_broker: AccessBrokerProtocol, ) -> None: """Initialize with embedders, vector database, and federated source dependencies.""" self._dense_embedder = dense_embedder @@ -75,6 +77,7 @@ def __init__( self._sc_repo = sc_repo self._source_registry = source_registry self._source_lifecycle = source_lifecycle + self._access_broker = access_broker async def execute( self, @@ -84,8 +87,14 @@ async def execute( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: Optional[str] = None, ) -> SearchResults: """Execute the full search pipeline including federated sources.""" + # 0. 
Resolve access control principals + acl_principals = await self._resolve_acl_principals( + db, ctx, user_principal, collection_readable_id + ) + # 1. Merge plan filters with user filters complete_plan = SearchPlanBuilder.build(plan, user_filter) @@ -106,7 +115,9 @@ async def execute( # 4. Run vector DB search and federated search in parallel fetch_limit = original_offset + original_limit - vector_task = asyncio.create_task(self._execute_vector_search(complete_plan, collection_id)) + vector_task = asyncio.create_task( + self._execute_vector_search(complete_plan, collection_id, acl_principals) + ) fed_task = None if federated_sources: @@ -130,6 +141,10 @@ async def execute( # Filter federated results in-memory and merge, or slice vector-only. fed_filtered = self._apply_filters_in_memory(fed_results, complete_plan.filter_groups) + # Also apply ACL filtering to federated results in-memory + if acl_principals is not None: + fed_filtered = self._apply_acl_in_memory(fed_filtered, acl_principals) + if fed_filtered: merged = self._merge_with_rrf(vector_results, fed_filtered) return SearchResults(results=merged[original_offset : original_offset + original_limit]) @@ -143,6 +158,7 @@ async def _execute_vector_search( self, plan: SearchPlan, collection_id: str, + acl_principals: Optional[list[str]] = None, ) -> list[SearchResult]: """Embed, compile, and execute vector DB search. 
@@ -174,9 +190,68 @@ async def _execute_vector_search( plan=plan, embeddings=embeddings, collection_id=collection_id, + acl_principals=acl_principals, ) return (await self._vector_db.execute_query(compiled_query)).results + # ------------------------------------------------------------------ + # Access control resolution + # ------------------------------------------------------------------ + + async def _resolve_acl_principals( + self, + db: AsyncSession, + ctx: ApiContext, + user_principal: Optional[str], + collection_readable_id: str, + ) -> Optional[list[str]]: + """Resolve user's ACL principals for a collection. + + Returns None if user_principal is not set or collection has no AC sources. + Returns a list of principals (possibly empty) otherwise. + """ + if not user_principal: + return None + + access_context = await self._access_broker.resolve_access_context_for_collection( + db=db, + user_principal=user_principal, + readable_collection_id=collection_readable_id, + organization_id=ctx.organization.id, + ) + + if access_context is None: + return None + + principals = list(access_context.all_principals) + ctx.logger.info(f"[ACL] Resolved {len(principals)} principals for user '{user_principal}'") + return principals + + @staticmethod + def _apply_acl_in_memory( + results: list[SearchResult], + principals: list[str], + ) -> list[SearchResult]: + """Apply ACL filtering to results in-memory (for federated sources). 
+ + Keeps results that are: + - From non-AC sources (is_public is None — no ACL data means pass through) + - Explicitly public (is_public is True) + - Matching a viewer principal + """ + principal_set = set(principals) + + def _passes(r: SearchResult) -> bool: + if r.access.is_public is None: + return True # Non-AC source — no access data, pass through + if r.access.is_public: + return True + if r.access.viewers: + return bool(principal_set & set(r.access.viewers)) + return False + + return [r for r in results if _passes(r)] + # ------------------------------------------------------------------ # Federated source discovery # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/search/fakes/executor.py b/backend/airweave/domains/search/fakes/executor.py index 5a97fb8d0..c49226d85 100644 --- a/backend/airweave/domains/search/fakes/executor.py +++ b/backend/airweave/domains/search/fakes/executor.py @@ -40,6 +40,7 @@ async def execute( db: Any = None, ctx: Any = None, collection_readable_id: str = "", + user_principal: str | None = None, ) -> SearchResults: """Record the call and return seeded result, or raise seeded error.""" self._calls.append(("execute", plan, user_filter, collection_id)) diff --git a/backend/airweave/domains/search/instant/service.py b/backend/airweave/domains/search/instant/service.py index 75ad1e7b0..5e973f8e2 100644 --- a/backend/airweave/domains/search/instant/service.py +++ b/backend/airweave/domains/search/instant/service.py @@ -49,13 +49,16 @@ async def search( ctx: ApiContext, readable_id: str, request: InstantSearchRequest, + user_principal_override: str | None = None, ) -> SearchResults: """Build plan from request and execute.""" start_time = time.monotonic() ctx.logger.info(f"Instant search started collection={readable_id} query={request.query!r}") try: - result = await self._execute(db, ctx, readable_id, request, start_time) + result = await self._execute( + db, ctx, readable_id, 
request, start_time, user_principal_override + ) duration_ms = int((time.monotonic() - start_time) * 1000) ctx.logger.info( f"Instant search completed collection={readable_id} " @@ -87,6 +90,7 @@ async def _execute( readable_id: str, request: InstantSearchRequest, start_time: float, + user_principal_override: str | None = None, ) -> SearchResults: """Internal execution — resolve collection, build plan, execute.""" collection = await self._collection_repo.get_by_readable_id(db, readable_id, ctx) @@ -107,6 +111,7 @@ async def _execute( db=db, ctx=ctx, collection_readable_id=readable_id, + user_principal=user_principal_override, ) duration_ms = int((time.monotonic() - start_time) * 1000) diff --git a/backend/airweave/domains/search/protocols.py b/backend/airweave/domains/search/protocols.py index 1ae3251b3..04cee0ae1 100644 --- a/backend/airweave/domains/search/protocols.py +++ b/backend/airweave/domains/search/protocols.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Optional, Protocol, runtime_checkable from sqlalchemy.ext.asyncio import AsyncSession @@ -39,6 +39,7 @@ async def execute( db: AsyncSession, ctx: ApiContext, collection_readable_id: str, + user_principal: Optional[str] = None, ) -> SearchResults: """Execute a search plan and return results.""" ... @@ -68,6 +69,7 @@ async def search( ctx: ApiContext, readable_id: str, request: InstantSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute instant search and return results.""" ... @@ -83,6 +85,7 @@ async def search( ctx: ApiContext, readable_id: str, request: ClassicSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute classic search and return results.""" ... 
@@ -98,6 +101,7 @@ async def search( ctx: ApiContext, readable_id: str, request: AgenticSearchRequest, + user_principal_override: Optional[str] = None, ) -> SearchResults: """Execute agentic search and return results.""" ... diff --git a/backend/airweave/domains/search/tests/test_config.py b/backend/airweave/domains/search/tests/test_config.py new file mode 100644 index 000000000..f1c267cba --- /dev/null +++ b/backend/airweave/domains/search/tests/test_config.py @@ -0,0 +1,84 @@ +"""Tests for LLM_FALLBACK_CHAIN env-var parser in search config.""" + +from __future__ import annotations + +import pytest + +from airweave.adapters.llm.registry import LLMModel, LLMProvider +from airweave.domains.search.config import ( + _DEFAULT_LLM_FALLBACK_CHAIN, + parse_llm_fallback_chain, +) + + +def test_none_returns_default_chain() -> None: + assert parse_llm_fallback_chain(None) == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_empty_string_returns_default_chain() -> None: + assert parse_llm_fallback_chain("") == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_whitespace_only_returns_default_chain() -> None: + assert parse_llm_fallback_chain(" ") == list(_DEFAULT_LLM_FALLBACK_CHAIN) + + +def test_single_entry_parsed_to_tuple() -> None: + parsed = parse_llm_fallback_chain("mistral:mistral-large") + assert parsed == [(LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE)] + + +def test_multiple_entries_preserve_order() -> None: + parsed = parse_llm_fallback_chain( + "together:zai-glm-5,anthropic:claude-sonnet-4.6,mistral:mistral-large" + ) + assert parsed == [ + (LLMProvider.TOGETHER, LLMModel.ZAI_GLM_5), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), + (LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE), + ] + + +def test_whitespace_around_entries_is_ignored() -> None: + parsed = parse_llm_fallback_chain(" mistral : mistral-large , anthropic : claude-sonnet-4.6 ") + assert parsed == [ + (LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE), + (LLMProvider.ANTHROPIC, LLMModel.CLAUDE_SONNET_4_6), 
+ ] + + +def test_unknown_provider_raises_with_accepted_list() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("bogus:mistral-large") + message = str(excinfo.value) + assert "bogus" in message + assert "mistral" in message + assert "anthropic" in message + + +def test_unknown_model_raises_with_accepted_list() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("mistral:not-a-real-model") + message = str(excinfo.value) + assert "not-a-real-model" in message + assert "mistral-large" in message + + +def test_missing_colon_raises_helpful_error() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("mistral-only") + assert "provider:model" in str(excinfo.value) + + +def test_trailing_comma_is_tolerated() -> None: + parsed = parse_llm_fallback_chain("mistral:mistral-large,") + assert parsed == [(LLMProvider.MISTRAL, LLMModel.MISTRAL_LARGE)] + + +def test_valid_enums_but_invalid_pair_raises() -> None: + with pytest.raises(ValueError) as excinfo: + parse_llm_fallback_chain("together:mistral-large") + message = str(excinfo.value) + assert "mistral-large" in message + assert "together" in message + assert "zai-glm-5" in message diff --git a/backend/airweave/domains/search/tests/test_executor.py b/backend/airweave/domains/search/tests/test_executor.py index 8da30518a..1768da077 100644 --- a/backend/airweave/domains/search/tests/test_executor.py +++ b/backend/airweave/domains/search/tests/test_executor.py @@ -17,6 +17,7 @@ import pytest +from airweave.domains.access_control.fakes.broker import FakeAccessBroker from airweave.domains.embedders.fakes.embedder import FakeDenseEmbedder, FakeSparseEmbedder from airweave.domains.search.adapters.vector_db.fakes.vector_db import FakeVectorDB from airweave.domains.search.executor import ( @@ -255,6 +256,7 @@ def _build_executor( sc_repo=sc_repo or FakeSourceConnectionRepository(), source_registry=source_registry or FakeSourceRegistry(), 
source_lifecycle=source_lifecycle or FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) @@ -913,6 +915,7 @@ async def test_dense_embedding_failure_propagates(self): sc_repo=FakeSourceConnectionRepository(), source_registry=FakeSourceRegistry(), source_lifecycle=FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) with pytest.raises(RuntimeError, match="provider down"): @@ -938,6 +941,7 @@ async def test_sparse_embedding_failure_propagates(self): sc_repo=FakeSourceConnectionRepository(), source_registry=FakeSourceRegistry(), source_lifecycle=FakeSourceLifecycleService(), + access_broker=FakeAccessBroker(), ) with pytest.raises(RuntimeError, match="timeout"): diff --git a/backend/airweave/domains/source_connections/create.py b/backend/airweave/domains/source_connections/create.py index e1c6ade60..63bc07aca 100644 --- a/backend/airweave/domains/source_connections/create.py +++ b/backend/airweave/domains/source_connections/create.py @@ -35,10 +35,7 @@ SourceValidationServiceProtocol, ) from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, -) +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol from airweave.models.connection_init_session import ConnectionInitStatus from airweave.schemas.connection import ConnectionCreate @@ -75,8 +72,7 @@ def __init__( source_registry: SourceRegistryProtocol, source_validation: SourceValidationServiceProtocol, source_lifecycle: SourceLifecycleServiceProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: SyncServiceProtocol, response_builder: ResponseBuilderProtocol, oauth_flow_service: OAuthFlowServiceProtocol, temporal_workflow_service: TemporalWorkflowServiceProtocol, @@ -92,8 +88,7 @@ def __init__( 
self._source_registry = source_registry self._source_validation = source_validation self._source_lifecycle = source_lifecycle - self._sync_lifecycle = sync_lifecycle - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._response_builder = response_builder self._oauth_flow_service = oauth_flow_service self._temporal_workflow_service = temporal_workflow_service @@ -450,23 +445,27 @@ async def _create_with_auth_provider( ) await uow.session.flush() connection_schema = schemas.Connection.model_validate(connection, from_attributes=True) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) - sync_result = await self._sync_lifecycle.provision_sync( - uow.session, - name=obj_in.name, - source_connection_id=connection.id, - destination_connection_ids=destination_ids, - collection_id=collection.id, - collection_readable_id=collection.readable_id, - source_entry=entry, - schedule_config=obj_in.schedule, - run_immediately=bool(obj_in.sync_immediately), - ctx=ctx, - uow=uow, - ) - await uow.session.flush() + + has_schedule = obj_in.schedule is None or ( + obj_in.schedule and obj_in.schedule.cron is not None + ) + sync_result = None + if bool(obj_in.sync_immediately) or has_schedule: + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + sync_result = await self._sync_service.create( + uow.session, + name=obj_in.name or entry.name, + source_connection_id=connection.id, + destination_connection_ids=destination_ids, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=entry, + schedule_config=obj_in.schedule, + run_immediately=bool(obj_in.sync_immediately), + ctx=ctx, + uow=uow, + ) + await uow.session.flush() source_conn = await self._sc_repo.create( uow.session, @@ -658,23 +657,28 @@ async def _create_authenticated_connection( ) await uow.session.flush() connection_schema = schemas.Connection.model_validate(connection, 
from_attributes=True) - destination_ids = await self._sync_record_service.resolve_destination_ids( - uow.session, ctx - ) - sync_result = await self._sync_lifecycle.provision_sync( - uow.session, - name=obj_in.name, - source_connection_id=connection.id, - destination_connection_ids=destination_ids, - collection_id=collection.id, - collection_readable_id=collection.readable_id, - source_entry=entry, - schedule_config=obj_in.schedule, - run_immediately=bool(obj_in.sync_immediately), - ctx=ctx, - uow=uow, - ) - await uow.session.flush() + + has_schedule = obj_in.schedule is None or ( + obj_in.schedule and obj_in.schedule.cron is not None + ) + sync_result = None + if bool(obj_in.sync_immediately) or has_schedule: + destination_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + sync_result = await self._sync_service.create( + uow.session, + name=obj_in.name or entry.name, + source_connection_id=connection.id, + destination_connection_ids=destination_ids, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=entry, + schedule_config=obj_in.schedule, + run_immediately=bool(obj_in.sync_immediately), + ctx=ctx, + uow=uow, + ) + await uow.session.flush() + source_conn = await self._sc_repo.create( uow.session, obj_in={ diff --git a/backend/airweave/domains/source_connections/delete.py b/backend/airweave/domains/source_connections/delete.py index 2c073b7ae..8a25a1416 100644 --- a/backend/airweave/domains/source_connections/delete.py +++ b/backend/airweave/domains/source_connections/delete.py @@ -1,22 +1,19 @@ """Source connection deletion service.""" -import asyncio from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession +from airweave import schemas from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException -from airweave.core.shared_models import SyncJobStatus from airweave.domains.collections.protocols import CollectionRepositoryProtocol from 
airweave.domains.source_connections.protocols import ( ResponseBuilderProtocol, SourceConnectionDeletionServiceProtocol, SourceConnectionRepositoryProtocol, ) -from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol -from airweave.domains.temporal.protocols import TemporalWorkflowServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.schemas.source_connection import SourceConnection as SourceConnectionSchema @@ -24,27 +21,21 @@ class SourceConnectionDeletionService(SourceConnectionDeletionServiceProtocol): """Deletes a source connection and all related data. The flow is: - 1. Cancel any running sync workflows and wait for them to stop. + 1. Delegate cancel + wait + cleanup scheduling to SyncService.delete. 2. CASCADE-delete the DB records (source connection, sync, jobs, entities). - 3. Fire-and-forget a Temporal cleanup workflow for the slow external - data deletion (Vespa, ARF, schedules) which can take minutes. 
""" - def __init__( + def __init__( # noqa: D107 self, sc_repo: SourceConnectionRepositoryProtocol, collection_repo: CollectionRepositoryProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, response_builder: ResponseBuilderProtocol, - temporal_workflow_service: TemporalWorkflowServiceProtocol, + sync_service: SyncServiceProtocol, ) -> None: self._sc_repo = sc_repo self._collection_repo = collection_repo - self._sync_job_repo = sync_job_repo - self._sync_lifecycle = sync_lifecycle self._response_builder = response_builder - self._temporal_workflow_service = temporal_workflow_service + self._sync_service = sync_service async def delete( self, @@ -58,112 +49,26 @@ async def delete( if not source_conn: raise NotFoundException("Source connection not found") - # Capture attributes upfront to avoid lazy-loading issues after session changes sync_id = source_conn.sync_id - collection = await self._collection_repo.get_by_readable_id( + collection_orm = await self._collection_repo.get_by_readable_id( db, readable_id=source_conn.readable_collection_id, ctx=ctx ) - if not collection: + if not collection_orm: raise NotFoundException("Collection not found") - collection_id = str(collection.id) - organization_id = str(collection.organization_id) + collection = schemas.CollectionRecord.model_validate(collection_orm, from_attributes=True) - # Build response before deletion response = await self._response_builder.build_response(db, source_conn, ctx) - # Cancel any running jobs and wait for the Temporal workflow to - # terminate before we cascade-delete the DB rows. 
if sync_id: - latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if latest_job and latest_job.status in [ - SyncJobStatus.PENDING, - SyncJobStatus.RUNNING, - SyncJobStatus.CANCELLING, - ]: - if latest_job.status in [SyncJobStatus.PENDING, SyncJobStatus.RUNNING]: - ctx.logger.info( - f"Cancelling job {latest_job.id} for source connection {id} before deletion" - ) - try: - await self._sync_lifecycle.cancel_job( - db, - source_connection_id=id, - job_id=latest_job.id, - ctx=ctx, - ) - except Exception as e: - ctx.logger.warning( - f"Failed to cancel job {latest_job.id} during deletion: {e}" - ) + await self._sync_service.delete( + db, + sync_id=sync_id, + collection_id=collection.id, + organization_id=collection.organization_id, + ctx=ctx, + cancel_timeout_seconds=15, + ) - # BARRIER: Wait for the workflow to reach a terminal state so - # the worker stops writing before we cascade-delete the rows. - reached_terminal = await self._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=15 - ) - if not reached_terminal: - ctx.logger.warning( - f"Job for sync {sync_id} did not reach terminal state within 15s " - f"-- proceeding with deletion anyway" - ) - - # Delete the source connection first (CASCADE removes sync, jobs, entities). await self._sc_repo.remove(db, id=id, ctx=ctx) - # Fire-and-forget: schedule async cleanup of external data (Vespa, ARF, - # Temporal schedules). This can take minutes for Vespa and must not - # block the API response. - if sync_id: - try: - await self._temporal_workflow_service.start_cleanup_sync_data_workflow( - sync_ids=[str(sync_id)], - collection_id=collection_id, - organization_id=organization_id, - ctx=ctx, - ) - except Exception as e: - ctx.logger.error( - f"Failed to schedule async cleanup for sync {sync_id}: {e}. " - f"Data may be orphaned in Vespa/ARF." 
- ) - return response - - async def _wait_for_sync_job_terminal_state( - self, - db: AsyncSession, - sync_id: UUID, - *, - timeout_seconds: int = 30, - poll_interval: float = 1.0, - ) -> bool: - """Wait for the latest sync job to reach a terminal state. - - Polls the database until the job reaches COMPLETED, FAILED, or CANCELLED. - Used as a cancellation barrier to prevent cleanup from running while - a Temporal worker is still actively writing. - - Args: - db: Database session. - sync_id: Sync ID whose latest job to monitor. - timeout_seconds: Maximum time to wait before giving up. - poll_interval: Seconds between poll attempts. - - Returns: - True if a terminal state was reached, False on timeout. - """ - terminal_states = { - SyncJobStatus.COMPLETED, - SyncJobStatus.FAILED, - SyncJobStatus.CANCELLED, - } - elapsed = 0.0 - while elapsed < timeout_seconds: - await asyncio.sleep(poll_interval) - elapsed += poll_interval - # Expire cached ORM objects to force a fresh read from the database - db.expire_all() - job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if job and job.status in terminal_states: - return True - return False diff --git a/backend/airweave/domains/source_connections/fakes/service.py b/backend/airweave/domains/source_connections/fakes/service.py index 547fd932b..8e0280e3e 100644 --- a/backend/airweave/domains/source_connections/fakes/service.py +++ b/backend/airweave/domains/source_connections/fakes/service.py @@ -12,7 +12,7 @@ SourceConnectionDeletionServiceProtocol, SourceConnectionUpdateServiceProtocol, ) -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( SourceConnection as SourceConnectionSchema, @@ -30,7 +30,7 @@ class FakeSourceConnectionService: def __init__( self, - sync_lifecycle: SyncLifecycleServiceProtocol, + 
sync_service: SyncServiceProtocol, create_service: Optional[SourceConnectionCreateServiceProtocol] = None, update_service: Optional[SourceConnectionUpdateServiceProtocol] = None, deletion_service: Optional[SourceConnectionDeletionServiceProtocol] = None, @@ -39,7 +39,7 @@ def __init__( self._list_items: List[SourceConnectionListItem] = [] self._redirect_urls: dict[str, str] = {} self._calls: list[tuple[Any, ...]] = [] - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service self._create_service = create_service self._update_service = update_service self._deletion_service = deletion_service @@ -112,7 +112,7 @@ async def run( force_full_sync: bool = False, ) -> SourceConnectionJob: self._calls.append(("run", db, id, ctx, force_full_sync)) - return await self._sync_lifecycle.run(db, id=id, ctx=ctx, force_full_sync=force_full_sync) + raise NotImplementedError("FakeSourceConnectionService.run not wired") async def get_jobs( self, @@ -123,7 +123,7 @@ async def get_jobs( limit: int = 100, ) -> List[SourceConnectionJob]: self._calls.append(("get_jobs", db, id, ctx, limit)) - return await self._sync_lifecycle.get_jobs(db, id=id, ctx=ctx, limit=limit) + return [] async def cancel_job( self, @@ -134,9 +134,7 @@ async def cancel_job( ctx: ApiContext, ) -> SourceConnectionJob: self._calls.append(("cancel_job", db, source_connection_id, job_id, ctx)) - return await self._sync_lifecycle.cancel_job( - db, source_connection_id=source_connection_id, job_id=job_id, ctx=ctx - ) + raise NotImplementedError("FakeSourceConnectionService.cancel_job not wired") async def get_sync_id(self, db: AsyncSession, *, id: UUID, ctx: ApiContext) -> dict: self._calls.append(("get_sync_id", db, id, ctx)) diff --git a/backend/airweave/domains/source_connections/service.py b/backend/airweave/domains/source_connections/service.py index 674bb1564..33cc05ef6 100644 --- a/backend/airweave/domains/source_connections/service.py +++ b/backend/airweave/domains/source_connections/service.py @@ 
-1,13 +1,17 @@ """Service for source connections.""" +from datetime import datetime from typing import List, Optional from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession +from airweave import schemas from airweave.api.context import ApiContext from airweave.core.datetime_utils import utc_now +from airweave.core.events.sync import SyncLifecycleEvent from airweave.core.exceptions import NotFoundException +from airweave.core.protocols.event_bus import EventBus from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -21,7 +25,7 @@ SourceConnectionUpdateServiceProtocol, ) from airweave.domains.sources.protocols import SourceRegistryProtocol -from airweave.domains.syncs.protocols import SyncLifecycleServiceProtocol +from airweave.domains.syncs.protocols import SyncServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( SourceConnection as SourceConnectionSchema, @@ -34,6 +38,14 @@ ) +def _duration_seconds( + started_at: Optional[datetime], completed_at: Optional[datetime] +) -> Optional[float]: + if started_at and completed_at: + return (completed_at - started_at).total_seconds() + return None + + class SourceConnectionService(SourceConnectionServiceProtocol): """Service for source connections.""" @@ -49,7 +61,8 @@ def __init__( # noqa: D107 auth_provider_registry: AuthProviderRegistryProtocol, # Helpers response_builder: ResponseBuilderProtocol, - sync_lifecycle: SyncLifecycleServiceProtocol, + sync_service: SyncServiceProtocol, + event_bus: EventBus, # Sub-services create_service: SourceConnectionCreateServiceProtocol, update_service: SourceConnectionUpdateServiceProtocol, @@ -62,7 +75,8 @@ def __init__( # noqa: D107 self.source_registry = source_registry self.auth_provider_registry = 
auth_provider_registry self.response_builder = response_builder - self._sync_lifecycle = sync_lifecycle + self._sync_service = sync_service + self._event_bus = event_bus self._create_service = create_service self._update_service = update_service self._deletion_service = deletion_service @@ -142,7 +156,8 @@ async def delete(self, db: AsyncSession, id: UUID, ctx: ApiContext) -> SourceCon return await self._deletion_service.delete(db, id=id, ctx=ctx) # ------------------------------------------------------------------ - # Sync lifecycle proxies + # Sync lifecycle proxies — resolve source_connection → sync_id, then + # delegate to the unified SyncService and map results. # ------------------------------------------------------------------ async def run( @@ -154,7 +169,52 @@ async def run( force_full_sync: bool = False, ) -> SourceConnectionJob: """Trigger a sync run for this source connection.""" - return await self._sync_lifecycle.run(db, id=id, ctx=ctx, force_full_sync=force_full_sync) + source_conn = await self._resolve_source_connection(db, id, ctx) + sync_id = source_conn.sync_id + assert sync_id is not None + + if force_full_sync: + await self._sync_service.validate_force_full_sync(db, sync_id, ctx) + + collection = await self._resolve_collection(db, source_conn, ctx) + connection = await self._resolve_connection(db, source_conn, ctx) + + sync, sync_job = await self._sync_service.trigger_run( + db, + sync_id=sync_id, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + await self._event_bus.publish( + SyncLifecycleEvent.pending( + organization_id=ctx.organization.id, + source_connection_id=id, + sync_job_id=sync_job.id, + sync_id=sync_id, + collection_id=collection.id, + source_type=connection.short_name, + collection_name=collection.name, + collection_readable_id=collection.readable_id, + ) + ) + + return SourceConnectionJob( + id=sync_job.id, + source_connection_id=id, + status=sync_job.status, + 
started_at=sync_job.started_at, + completed_at=sync_job.completed_at, + duration_seconds=_duration_seconds(sync_job.started_at, sync_job.completed_at), + entities_inserted=sync_job.entities_inserted or 0, + entities_updated=sync_job.entities_updated or 0, + entities_deleted=sync_job.entities_deleted or 0, + entities_failed=sync_job.entities_skipped or 0, + error=sync_job.error, + error_category=sync_job.error_category, + ) async def get_jobs( self, @@ -165,7 +225,29 @@ async def get_jobs( limit: int = 100, ) -> List[SourceConnectionJob]: """List sync jobs for this source connection.""" - return await self._sync_lifecycle.get_jobs(db, id=id, ctx=ctx, limit=limit) + source_conn = await self._resolve_source_connection(db, id, ctx) + sync_id = source_conn.sync_id + assert sync_id is not None + + jobs = await self._sync_service.get_jobs(db, sync_id=sync_id, ctx=ctx, limit=limit) + + return [ + SourceConnectionJob( + id=j.id, + source_connection_id=id, + status=j.status, + started_at=j.started_at, + completed_at=j.completed_at, + duration_seconds=_duration_seconds(j.started_at, j.completed_at), + entities_inserted=j.entities_inserted or 0, + entities_updated=j.entities_updated or 0, + entities_deleted=j.entities_deleted or 0, + entities_failed=j.entities_skipped or 0, + error=j.error, + error_category=j.error_category, + ) + for j in jobs + ] async def cancel_job( self, @@ -176,8 +258,21 @@ async def cancel_job( ctx: ApiContext, ) -> SourceConnectionJob: """Cancel a running sync job.""" - return await self._sync_lifecycle.cancel_job( - db, source_connection_id=source_connection_id, job_id=job_id, ctx=ctx + sync_job = await self._sync_service.cancel_job(db, job_id=job_id, ctx=ctx) + + return SourceConnectionJob( + id=sync_job.id, + source_connection_id=source_connection_id, + status=sync_job.status, + started_at=sync_job.started_at, + completed_at=sync_job.completed_at, + duration_seconds=_duration_seconds(sync_job.started_at, sync_job.completed_at), + 
entities_inserted=sync_job.entities_inserted or 0, + entities_updated=sync_job.entities_updated or 0, + entities_deleted=sync_job.entities_deleted or 0, + entities_failed=sync_job.entities_skipped or 0, + error=sync_job.error, + error_category=sync_job.error_category, ) async def get_sync_id(self, db: AsyncSession, *, id: UUID, ctx: ApiContext) -> dict: @@ -194,15 +289,50 @@ async def count_by_organization(self, db: AsyncSession, organization_id: UUID) - return await self.sc_repo.count_by_organization(db, organization_id) async def get_redirect_url(self, db: AsyncSession, *, code: str) -> str: - """Resolve a short redirect code to its final OAuth authorization URL. - - The redirect session is atomically consumed (deleted) on lookup, - enforcing one-time use per CASA Requirement #23. - """ + """Resolve a short redirect code to its final OAuth authorization URL.""" redirect_info = await self._redirect_session_repo.consume(db, code=code) if not redirect_info: raise NotFoundException("Authorization link expired or invalid") - # Check expiry *after* consume so expired tokens can't be replayed. 
if redirect_info.expires_at <= utc_now(): raise NotFoundException("Authorization link expired or invalid") return redirect_info.final_url + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + async def _resolve_source_connection( + self, db: AsyncSession, id: UUID, ctx: ApiContext + ) -> SourceConnection: + """Get a source connection and validate it has an associated sync.""" + source_conn = await self.sc_repo.get(db, id=id, ctx=ctx) + if not source_conn: + raise NotFoundException("Source connection not found") + if not source_conn.sync_id: + raise NotFoundException("No sync found for this source connection") + return source_conn + + async def _resolve_collection( + self, db: AsyncSession, source_conn: SourceConnection, ctx: ApiContext + ) -> schemas.CollectionRecord: + """Resolve the CollectionRecord schema for a source connection.""" + readable_id = source_conn.readable_collection_id + if not readable_id: + raise NotFoundException( + f"Source connection {source_conn.id} has no readable_collection_id" + ) + collection = await self.collection_repo.get_by_readable_id(db, str(readable_id), ctx) + if not collection: + raise NotFoundException("Collection not found") + return schemas.CollectionRecord.model_validate(collection, from_attributes=True) + + async def _resolve_connection( + self, db: AsyncSession, source_conn: SourceConnection, ctx: ApiContext + ) -> schemas.Connection: + """Resolve the Connection schema (not SourceConnection) for a source connection.""" + if not source_conn.connection_id: + raise NotFoundException(f"Source connection {source_conn.id} has no connection_id") + conn = await self.connection_repo.get(db, source_conn.connection_id, ctx) + if not conn: + raise NotFoundException(f"Connection {source_conn.connection_id} not found") + return schemas.Connection.model_validate(conn, from_attributes=True) diff --git 
a/backend/airweave/domains/source_connections/tests/test_create.py b/backend/airweave/domains/source_connections/tests/test_create.py index 61eef2dec..b57f596f2 100644 --- a/backend/airweave/domains/source_connections/tests/test_create.py +++ b/backend/airweave/domains/source_connections/tests/test_create.py @@ -26,8 +26,7 @@ from airweave.domains.sources.fakes.registry import FakeSourceRegistry from airweave.domains.sources.fakes.validation import FakeSourceValidationService from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService from airweave.schemas.organization import Organization from airweave.schemas.source_connection import ( @@ -86,8 +85,7 @@ def _service(entry) -> SourceConnectionCreationService: source_registry=registry, source_validation=FakeSourceValidationService(), source_lifecycle=FakeSourceLifecycleService(), - sync_lifecycle=FakeSyncLifecycleService(), - sync_record_service=FakeSyncRecordService(), + sync_service=FakeSyncService(), response_builder=FakeResponseBuilder(), oauth_flow_service=FakeOAuthFlowService(), temporal_workflow_service=FakeTemporalWorkflowService(), diff --git a/backend/airweave/domains/source_connections/tests/test_delete.py b/backend/airweave/domains/source_connections/tests/test_delete.py index 22b8d07fc..a70bf4428 100644 --- a/backend/airweave/domains/source_connections/tests/test_delete.py +++ b/backend/airweave/domains/source_connections/tests/test_delete.py @@ -1,37 +1,31 @@ """Unit tests for SourceConnectionDeletionService. 
Table-driven tests covering: -- Happy paths: no sync, completed job, running/cancelling/pending jobs -- Error paths: not found, collection not found, cancel failure, cleanup failure, timeout +- Happy paths: no sync vs with sync (delegates to sync_service.delete) +- Error paths: not found, collection not found """ from dataclasses import dataclass from datetime import datetime, timezone -from typing import Optional from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest -import airweave.domains.source_connections.delete as delete_module from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncJobStatus +from airweave.core.shared_models import AuthMethod from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.source_connections.delete import SourceConnectionDeletionService from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder -from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService -from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.models.collection import Collection from airweave.models.source_connection import SourceConnection -from airweave.models.sync_job import SyncJob from airweave.schemas.organization import Organization -from airweave.schemas.source_connection import SourceConnectionJob NOW = datetime.now(timezone.utc) ORG_ID = uuid4() @@ -67,33 +61,28 @@ def _make_collection(*, id=None, readable_id="test-col"): col = MagicMock(spec=Collection) col.id = id or COLLECTION_ID col.readable_id = readable_id + 
col.name = "Test Collection" col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None return col -def _make_job(*, status=SyncJobStatus.COMPLETED, sync_id=None): - job = MagicMock(spec=SyncJob) - job.id = uuid4() - job.sync_id = sync_id or uuid4() - job.status = status - return job - - def _build_service( sc_repo=None, collection_repo=None, - sync_job_repo=None, - sync_lifecycle=None, response_builder=None, - temporal_workflow_service=None, + sync_service=None, ): return SourceConnectionDeletionService( sc_repo=sc_repo or FakeSourceConnectionRepository(), collection_repo=collection_repo or FakeCollectionRepository(), - sync_job_repo=sync_job_repo or FakeSyncJobRepository(), - sync_lifecycle=sync_lifecycle or FakeSyncLifecycleService(), response_builder=response_builder or FakeResponseBuilder(), - temporal_workflow_service=temporal_workflow_service or FakeTemporalWorkflowService(), + sync_service=sync_service or FakeSyncService(), ) @@ -106,18 +95,12 @@ def _build_service( class DeleteCase: desc: str has_sync: bool - job_status: Optional[SyncJobStatus] - expect_cancel: bool - expect_wait: bool - expect_cleanup: bool + expect_sync_delete: bool DELETE_CASES = [ - DeleteCase("no_sync", has_sync=False, job_status=None, expect_cancel=False, expect_wait=False, expect_cleanup=False), - DeleteCase("sync_no_running_job", has_sync=True, job_status=SyncJobStatus.COMPLETED, expect_cancel=False, expect_wait=False, expect_cleanup=True), - DeleteCase("running_job", has_sync=True, job_status=SyncJobStatus.RUNNING, expect_cancel=True, expect_wait=True, expect_cleanup=True), - DeleteCase("cancelling_job", has_sync=True, job_status=SyncJobStatus.CANCELLING, expect_cancel=False, expect_wait=True, expect_cleanup=True), - DeleteCase("pending_job", has_sync=True, job_status=SyncJobStatus.PENDING, expect_cancel=True, expect_wait=True, 
expect_cleanup=True), + DeleteCase("no_sync", has_sync=False, expect_sync_delete=False), + DeleteCase("with_sync", has_sync=True, expect_sync_delete=True), ] @@ -132,38 +115,23 @@ async def test_delete_happy_path(case: DeleteCase): col_repo = FakeCollectionRepository() col_repo.seed_readable(sc.readable_collection_id, col) - job_repo = FakeSyncJobRepository() - if case.job_status is not None and sync_id: - job_repo.seed_last_job(sync_id, _make_job(status=case.job_status, sync_id=sync_id)) - - lifecycle = FakeSyncLifecycleService() - if case.expect_cancel: - lifecycle.set_cancel_result(MagicMock(spec=SourceConnectionJob)) - - temporal = FakeTemporalWorkflowService() + sync_service = FakeSyncService() svc = _build_service( sc_repo=sc_repo, collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, + sync_service=sync_service, ) - if case.expect_wait: - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=True) - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) assert result.id == sc.id assert sc_repo._store.get(sc.id) is None - if case.expect_cancel: - assert any(c[0] == "cancel_job" for c in lifecycle._calls) - if case.expect_wait: - svc._wait_for_sync_job_terminal_state.assert_awaited_once() - if case.expect_cleanup: - assert any(c[0] == "start_cleanup_sync_data_workflow" for c in temporal._calls) + if case.expect_sync_delete: + assert any(c[0] == "delete" for c in sync_service._calls) + else: + assert not any(c[0] == "delete" for c in sync_service._calls) # --------------------------------------------------------------------------- @@ -203,8 +171,8 @@ async def test_delete_error(case: DeleteErrorCase): await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) -async def test_delete_cancel_failure_is_swallowed(): - """Cancel failure during delete is warned but not re-raised.""" +async def test_delete_sync_service_failure_propagates(): + """If sync_service.delete raises, the error 
propagates.""" sync_id = uuid4() sc = _make_sc(sync_id=sync_id) col = _make_collection() @@ -214,132 +182,14 @@ async def test_delete_cancel_failure_is_swallowed(): col_repo = FakeCollectionRepository() col_repo.seed_readable(sc.readable_collection_id, col) - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - job_repo = FakeSyncJobRepository() - job_repo.seed_last_job(sync_id, running_job) - - lifecycle = FakeSyncLifecycleService() - lifecycle.set_error(RuntimeError("cancel boom")) - temporal = FakeTemporalWorkflowService() + sync_service = FakeSyncService() + sync_service.set_error(RuntimeError("sync delete boom")) svc = _build_service( sc_repo=sc_repo, collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, + sync_service=sync_service, ) - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=True) - - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - -async def test_delete_temporal_cleanup_failure_is_logged(): - """Temporal cleanup failure is logged but not re-raised.""" - sync_id = uuid4() - sc = _make_sc(sync_id=sync_id) - col = _make_collection() - - sc_repo = FakeSourceConnectionRepository() - sc_repo.seed(sc.id, sc) - col_repo = FakeCollectionRepository() - col_repo.seed_readable(sc.readable_collection_id, col) - - job_repo = FakeSyncJobRepository() - completed_job = _make_job(status=SyncJobStatus.COMPLETED, sync_id=sync_id) - job_repo.seed_last_job(sync_id, completed_job) - - temporal = FakeTemporalWorkflowService() - temporal.set_error(RuntimeError("cleanup boom")) - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=col_repo, - sync_job_repo=job_repo, - temporal_workflow_service=temporal, - ) - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - - -async def test_delete_wait_timeout_proceeds(): - """If wait_for_terminal_state returns False, deletion still 
proceeds.""" - sync_id = uuid4() - sc = _make_sc(sync_id=sync_id) - col = _make_collection() - - sc_repo = FakeSourceConnectionRepository() - sc_repo.seed(sc.id, sc) - col_repo = FakeCollectionRepository() - col_repo.seed_readable(sc.readable_collection_id, col) - - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - job_repo = FakeSyncJobRepository() - job_repo.seed_last_job(sync_id, running_job) - - lifecycle = FakeSyncLifecycleService() - lifecycle.set_cancel_result(MagicMock(spec=SourceConnectionJob)) - temporal = FakeTemporalWorkflowService() - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=col_repo, - sync_job_repo=job_repo, - sync_lifecycle=lifecycle, - temporal_workflow_service=temporal, - ) - svc._wait_for_sync_job_terminal_state = AsyncMock(return_value=False) - - result = await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) - assert result.id == sc.id - assert sc_repo._store.get(sc.id) is None - - -async def test_wait_for_sync_job_terminal_state_reaches_terminal(monkeypatch): - sync_id = uuid4() - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - cancelled_job = _make_job(status=SyncJobStatus.CANCELLED, sync_id=sync_id) - - job_repo = FakeSyncJobRepository() - job_repo.get_latest_by_sync_id = AsyncMock(side_effect=[running_job, cancelled_job]) # type: ignore[method-assign] - svc = _build_service(sync_job_repo=job_repo) - - async def _no_sleep(_: float) -> None: - return None - - monkeypatch.setattr(delete_module.asyncio, "sleep", _no_sleep) - - db = MagicMock() - db.expire_all = MagicMock() - - reached = await svc._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=2, poll_interval=1 - ) - - assert reached is True - assert db.expire_all.call_count == 2 - - -async def test_wait_for_sync_job_terminal_state_times_out(monkeypatch): - sync_id = uuid4() - running_job = _make_job(status=SyncJobStatus.RUNNING, sync_id=sync_id) - - job_repo = FakeSyncJobRepository() - 
job_repo.get_latest_by_sync_id = AsyncMock(side_effect=[running_job, running_job]) # type: ignore[method-assign] - svc = _build_service(sync_job_repo=job_repo) - - async def _no_sleep(_: float) -> None: - return None - - monkeypatch.setattr(delete_module.asyncio, "sleep", _no_sleep) - - db = MagicMock() - db.expire_all = MagicMock() - - reached = await svc._wait_for_sync_job_terminal_state( - db, sync_id, timeout_seconds=2, poll_interval=1 - ) - - assert reached is False - assert db.expire_all.call_count == 2 + with pytest.raises(RuntimeError, match="sync delete boom"): + await svc.delete(AsyncMock(), id=sc.id, ctx=_make_ctx()) diff --git a/backend/airweave/domains/source_connections/tests/test_fake_service.py b/backend/airweave/domains/source_connections/tests/test_fake_service.py index bd619d0e8..bba4f3e76 100644 --- a/backend/airweave/domains/source_connections/tests/test_fake_service.py +++ b/backend/airweave/domains/source_connections/tests/test_fake_service.py @@ -8,7 +8,7 @@ from airweave.core.exceptions import NotFoundException from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService from airweave.domains.source_connections.fakes.service import FakeSourceConnectionService -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService def _make_source_connection(): @@ -23,7 +23,7 @@ async def test_delete_delegation_removes_fake_store_entry_on_success(): deletion_service.seed_response(source_connection.id, MagicMock()) service = FakeSourceConnectionService( - sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), deletion_service=deletion_service, ) service.seed(source_connection.id, source_connection) @@ -40,7 +40,7 @@ async def test_delete_delegation_keeps_fake_store_entry_when_delegate_raises(): deletion_service.set_should_raise(RuntimeError("boom")) service = FakeSourceConnectionService( - 
sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), deletion_service=deletion_service, ) service.seed(source_connection.id, source_connection) diff --git a/backend/airweave/domains/source_connections/tests/test_service.py b/backend/airweave/domains/source_connections/tests/test_service.py index b8f2a911e..def0b1169 100644 --- a/backend/airweave/domains/source_connections/tests/test_service.py +++ b/backend/airweave/domains/source_connections/tests/test_service.py @@ -10,31 +10,30 @@ from datetime import datetime, timedelta from typing import Optional from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 +from uuid import UUID, uuid4 import pytest +from airweave.api.context import ApiContext from airweave.core.datetime_utils import utc_now from airweave.core.exceptions import NotFoundException - -from airweave.api.context import ApiContext from airweave.core.logging import logger from airweave.core.shared_models import AuthMethod, SourceConnectionStatus, SyncJobStatus from airweave.domains.auth_provider.fake import FakeAuthProviderRegistry from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.connections.fakes.repository import FakeConnectionRepository from airweave.domains.oauth.fakes.repository import FakeOAuthRedirectSessionRepository +from airweave.domains.source_connections.fakes.create import FakeSourceConnectionCreateService +from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService from airweave.domains.source_connections.fakes.repository import ( FakeSourceConnectionRepository, ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder +from airweave.domains.source_connections.fakes.update import FakeSourceConnectionUpdateService from airweave.domains.source_connections.service import SourceConnectionService from airweave.domains.source_connections.types import LastJobInfo, SourceConnectionStats from 
airweave.domains.sources.fakes.registry import FakeSourceRegistry -from airweave.domains.source_connections.fakes.delete import FakeSourceConnectionDeletionService -from airweave.domains.source_connections.fakes.create import FakeSourceConnectionCreateService -from airweave.domains.source_connections.fakes.update import FakeSourceConnectionUpdateService -from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.schemas.organization import Organization from airweave.schemas.source_connection import AuthenticationMethod, SourceConnectionListItem @@ -99,7 +98,8 @@ def _build_service( source_registry=FakeSourceRegistry(), auth_provider_registry=FakeAuthProviderRegistry(), response_builder=FakeResponseBuilder(), - sync_lifecycle=FakeSyncLifecycleService(), + sync_service=FakeSyncService(), + event_bus=AsyncMock(), create_service=FakeSourceConnectionCreateService(), update_service=FakeSourceConnectionUpdateService(), deletion_service=FakeSourceConnectionDeletionService(), @@ -399,3 +399,293 @@ async def test_get_redirect_url_missing_code_raises(): svc = _build_service() with pytest.raises(NotFoundException, match="Authorization link expired or invalid"): await svc.get_redirect_url(AsyncMock(), code="nonexistent") + + +# --------------------------------------------------------------------------- +# Sync lifecycle proxies: run, get_jobs, cancel_job, count_by_organization +# --------------------------------------------------------------------------- + +SYNC_ID = uuid4() +SC_ID = uuid4() +JOB_ID = uuid4() +CONN_ID = uuid4() +COL_READABLE_ID = "test-col-abc" + + +def _make_source_conn( + *, + id: UUID = SC_ID, + sync_id: UUID = SYNC_ID, + connection_id: UUID = CONN_ID, + readable_collection_id: str = COL_READABLE_ID, +): + sc = MagicMock() + sc.id = id + sc.sync_id = sync_id + sc.connection_id = connection_id + sc.readable_collection_id = readable_collection_id + return 
sc + + +def _make_collection(readable_id: str = COL_READABLE_ID): + col = MagicMock() + col.id = uuid4() + col.name = "Test Collection" + col.readable_id = readable_id + col.organization_id = ORG_ID + return col + + +def _make_connection(id: UUID = CONN_ID): + conn = MagicMock() + conn.id = id + conn.short_name = "github" + conn.name = "GitHub" + return conn + + +def _make_sync_schema(): + return MagicMock(spec_set=["id", "status"]) + + +def _make_sync_job_schema(*, sync_id: UUID = SYNC_ID, job_id: UUID = JOB_ID): + job = MagicMock() + job.id = job_id + job.sync_id = sync_id + job.status = SyncJobStatus.PENDING + job.started_at = None + job.completed_at = None + job.error = None + job.error_category = None + job.entities_inserted = 0 + job.entities_updated = 0 + job.entities_deleted = 0 + job.entities_skipped = 0 + return job + + +def _build_run_service( + sc_repo=None, + sync_service=None, + collection_repo=None, + connection_repo=None, + event_bus=None, +): + return SourceConnectionService( + sc_repo=sc_repo or FakeSourceConnectionRepository(), + collection_repo=collection_repo or FakeCollectionRepository(), + connection_repo=connection_repo or FakeConnectionRepository(), + redirect_session_repo=FakeOAuthRedirectSessionRepository(), + source_registry=FakeSourceRegistry(), + auth_provider_registry=FakeAuthProviderRegistry(), + response_builder=FakeResponseBuilder(), + sync_service=sync_service or FakeSyncService(), + event_bus=event_bus or AsyncMock(), + create_service=FakeSourceConnectionCreateService(), + update_service=FakeSourceConnectionUpdateService(), + deletion_service=FakeSourceConnectionDeletionService(), + ) + + +class _RecordingFakeSyncService(FakeSyncService): + """Records keyword arguments passed to trigger_run for assertions.""" + + def __init__(self) -> None: + super().__init__() + self.last_trigger_run: Optional[dict] = None + + async def trigger_run( + self, + db, + *, + sync_id, + collection, + connection, + ctx, + force_full_sync: bool = False, 
+ ): + self.last_trigger_run = { + "sync_id": sync_id, + "collection": collection, + "connection": connection, + "force_full_sync": force_full_sync, + } + return await super().trigger_run( + db, + sync_id=sync_id, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + +async def test_run_triggers_workflow_and_returns_job(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = _RecordingFakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + event_bus = AsyncMock() + + svc = _build_run_service( + sc_repo=sc_repo, + sync_service=sync_svc, + event_bus=event_bus, + ) + + col_id = uuid4() + col_schema = MagicMock(id=col_id, readable_id="col-x") + col_schema.name = "Col" + conn_schema = MagicMock(short_name="github") + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=conn_schema) + + result = await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + assert result.id == JOB_ID + assert result.source_connection_id == SC_ID + assert result.status == SyncJobStatus.PENDING + assert sync_svc.last_trigger_run is not None + assert sync_svc.last_trigger_run["sync_id"] == SYNC_ID + assert sync_svc.last_trigger_run["collection"] is col_schema + assert sync_svc.last_trigger_run["connection"] is conn_schema + assert sync_svc.last_trigger_run["force_full_sync"] is False + + +async def test_run_event_failure_propagates(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = FakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + event_bus = AsyncMock() + event_bus.publish.side_effect = RuntimeError("event bus down") + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc, event_bus=event_bus) + col_schema = MagicMock(id=uuid4(), readable_id="col-x") + 
col_schema.name = "Col" + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=MagicMock(short_name="github")) + + with pytest.raises(RuntimeError, match="event bus down"): + await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + +async def test_run_not_found_raises(): + svc = _build_run_service() + with pytest.raises(NotFoundException, match="Source connection not found"): + await svc.run(AsyncMock(), id=uuid4(), ctx=_make_ctx()) + + +async def test_get_jobs_returns_mapped_jobs(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = FakeSyncService() + j1 = _make_sync_job_schema(job_id=uuid4()) + j1.entities_inserted = 42 + j1.entities_updated = 3 + j2 = _make_sync_job_schema(job_id=uuid4()) + sync_svc.seed_jobs(SYNC_ID, [j1, j2]) + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc) + jobs = await svc.get_jobs(AsyncMock(), id=SC_ID, ctx=_make_ctx()) + + assert len(jobs) == 2 + assert jobs[0].source_connection_id == SC_ID + assert jobs[0].entities_inserted == 42 + assert jobs[0].entities_updated == 3 + assert jobs[1].source_connection_id == SC_ID + assert jobs[1].entities_inserted == 0 + + +async def test_get_jobs_not_found_raises(): + svc = _build_run_service() + with pytest.raises(NotFoundException, match="Source connection not found"): + await svc.get_jobs(AsyncMock(), id=uuid4(), ctx=_make_ctx()) + + +async def test_cancel_job_delegates_to_sync_service(): + sync_svc = FakeSyncService() + job = _make_sync_job_schema(sync_id=SYNC_ID) + sync_svc.set_cancel_result(job) + + svc = _build_run_service(sync_service=sync_svc) + result = await svc.cancel_job( + AsyncMock(), source_connection_id=SC_ID, job_id=JOB_ID, ctx=_make_ctx() + ) + assert result.id == JOB_ID + + +async def test_run_with_force_full_sync(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + sync_svc = 
_RecordingFakeSyncService() + sync_svc.set_trigger_run_result(_make_sync_schema(), _make_sync_job_schema()) + + svc = _build_run_service(sc_repo=sc_repo, sync_service=sync_svc) + col_schema = MagicMock(id=uuid4(), readable_id="col-x") + col_schema.name = "Col" + svc._resolve_collection = AsyncMock(return_value=col_schema) + svc._resolve_connection = AsyncMock(return_value=MagicMock(short_name="github")) + + result = await svc.run(AsyncMock(), id=SC_ID, ctx=_make_ctx(), force_full_sync=True) + assert result.id == JOB_ID + assert ("validate_force_full_sync", SYNC_ID) in sync_svc._calls + assert sync_svc.last_trigger_run is not None + assert sync_svc.last_trigger_run["force_full_sync"] is True + + +async def test_resolve_collection_not_found(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + col_repo = FakeCollectionRepository() + svc = _build_run_service(sc_repo=sc_repo, collection_repo=col_repo) + + with pytest.raises(NotFoundException, match="Collection not found"): + await svc._resolve_collection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_collection_no_readable_id(): + sc = _make_source_conn(readable_collection_id=None) + svc = _build_run_service() + + with pytest.raises(NotFoundException, match="has no readable_collection_id"): + await svc._resolve_collection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_connection_not_found(): + sc = _make_source_conn() + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(SC_ID, sc) + + conn_repo = FakeConnectionRepository() + svc = _build_run_service(sc_repo=sc_repo, connection_repo=conn_repo) + + with pytest.raises(NotFoundException, match="not found"): + await svc._resolve_connection(AsyncMock(), sc, _make_ctx()) + + +async def test_resolve_connection_no_connection_id(): + sc = _make_source_conn(connection_id=None) + svc = _build_run_service() + + with pytest.raises(NotFoundException, match="has no connection_id"): + await 
svc._resolve_connection(AsyncMock(), sc, _make_ctx()) + + +async def test_count_by_organization(): + sc_repo = FakeSourceConnectionRepository() + svc = _build_run_service(sc_repo=sc_repo) + count = await svc.count_by_organization(AsyncMock(), organization_id=ORG_ID) + assert count == 0 diff --git a/backend/airweave/domains/source_connections/tests/test_update.py b/backend/airweave/domains/source_connections/tests/test_update.py index 9063c4bc1..581cec880 100644 --- a/backend/airweave/domains/source_connections/tests/test_update.py +++ b/backend/airweave/domains/source_connections/tests/test_update.py @@ -18,7 +18,7 @@ from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncStatus +from airweave.core.shared_models import AuthMethod from airweave.domains.syncs.types import InvalidSyncTransitionError from airweave.domains.collections.fakes.repository import FakeCollectionRepository from airweave.domains.connections.fakes.repository import FakeConnectionRepository @@ -30,10 +30,13 @@ ) from airweave.domains.source_connections.fakes.response import FakeResponseBuilder from airweave.domains.source_connections.update import SourceConnectionUpdateService +from airweave.domains.sources.fakes.registry import FakeSourceRegistry from airweave.domains.sources.fakes.service import FakeSourceService from airweave.domains.sources.fakes.validation import FakeSourceValidationService -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService +from airweave.domains.sources.types import SourceRegistryEntry +from airweave.domains.syncs.fakes.service import FakeSyncService from airweave.domains.syncs.fakes.repository import FakeSyncRepository +from airweave.domains.syncs.types import SyncProvisionResult from airweave.domains.temporal.fakes.schedule_service import FakeTemporalScheduleService from airweave.models.collection import 
Collection from airweave.models.connection import Connection @@ -105,13 +108,13 @@ def _build_service( connection_repo=None, cred_repo=None, sync_repo=None, - sync_record_service=None, + sync_service=None, source_service=None, + source_registry=None, source_validation=None, credential_encryptor=None, response_builder=None, temporal_schedule_service=None, - sync_state_machine=None, ): return SourceConnectionUpdateService( sc_repo=sc_repo or FakeSourceConnectionRepository(), @@ -119,13 +122,13 @@ def _build_service( connection_repo=connection_repo or FakeConnectionRepository(), cred_repo=cred_repo or FakeIntegrationCredentialRepository(), sync_repo=sync_repo or FakeSyncRepository(), - sync_record_service=sync_record_service or FakeSyncRecordService(), + sync_service=sync_service or FakeSyncService(), source_service=source_service or FakeSourceService(), + source_registry=source_registry or FakeSourceRegistry(), source_validation=source_validation or FakeSourceValidationService(), credential_encryptor=credential_encryptor or FakeCredentialEncryptor(), response_builder=response_builder or FakeResponseBuilder(), temporal_schedule_service=temporal_schedule_service or FakeTemporalScheduleService(), - sync_state_machine=sync_state_machine or AsyncMock(), ) @@ -211,7 +214,7 @@ class ScheduleCase: SCHEDULE_CASES = [ ScheduleCase("update_existing", has_sync=True, new_cron="0 * * * *", expect_temporal_create=True, expect_temporal_delete=False), ScheduleCase("remove_schedule", has_sync=True, new_cron=None, expect_temporal_create=False, expect_temporal_delete=True), - ScheduleCase("add_no_sync", has_sync=False, new_cron="0 * * * *", expect_temporal_create=True, expect_temporal_delete=False, expect_sync_record_create=True), + ScheduleCase("add_no_sync", has_sync=False, new_cron="0 * * * *", expect_temporal_create=False, expect_temporal_delete=False, expect_sync_record_create=True), ScheduleCase("no_connection_id_warning", has_sync=False, new_cron="0 * * * *", 
has_connection_id=False, expect_temporal_create=False, expect_temporal_delete=False), ] @@ -235,17 +238,38 @@ async def test_schedule_update(case: ScheduleCase): source_svc.seed(_make_source_schema(short_name="github")) temporal = FakeTemporalScheduleService() + sync_svc = FakeSyncService() + + source_registry = FakeSourceRegistry() + source_entry = MagicMock(spec=SourceRegistryEntry) + source_entry.short_name = "github" + source_entry.federated_search = False + source_registry.seed(source_entry) - sync_record_svc = FakeSyncRecordService() if case.expect_sync_record_create: - mock_sync = MagicMock(spec=schemas.Sync) - mock_sync.id = uuid4() - sync_record_svc.set_create_result(mock_sync) + created_sync_id = uuid4() + mock_sync_schema = MagicMock(spec=schemas.Sync) + mock_sync_schema.id = created_sync_id + sync_svc.set_create_result( + SyncProvisionResult( + sync_id=created_sync_id, + sync=mock_sync_schema, + sync_job=None, + cron_schedule=case.new_cron, + ) + ) col = MagicMock(spec=Collection) col.id = uuid4() col.readable_id = "test-col" + col.name = "Test Collection" col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None col_repo = FakeCollectionRepository() col_repo.seed_readable("test-col", col) else: @@ -257,9 +281,10 @@ async def test_schedule_update(case: ScheduleCase): svc = _build_service( sc_repo=sc_repo, sync_repo=sync_repo, + sync_service=sync_svc, source_service=source_svc, + source_registry=source_registry, temporal_schedule_service=temporal, - sync_record_service=sync_record_svc, collection_repo=col_repo, ) @@ -271,7 +296,7 @@ async def test_schedule_update(case: ScheduleCase): if case.expect_temporal_delete: assert any(c[0] == "delete_all_schedules_for_sync" for c in temporal._calls) if case.expect_sync_record_create: - assert any(c[0] == "create_sync" for c in sync_record_svc._calls) + assert 
any(c[0] == "create" for c in sync_svc._calls) async def test_schedule_add_collection_not_found(): @@ -289,6 +314,45 @@ async def test_schedule_add_collection_not_found(): await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) +async def test_schedule_add_rejects_federated_source(): + """Adding a schedule to a federated search source is rejected with 400.""" + sc = _make_sc(sync_id=None) + sc_repo = FakeSourceConnectionRepository() + sc_repo.seed(sc.id, sc) + + col = MagicMock(spec=Collection) + col.id = uuid4() + col.readable_id = "test-col" + col.name = "Test Collection" + col.organization_id = ORG_ID + col.vector_db_deployment_metadata_id = uuid4() + col.sync_config = None + col.created_at = NOW + col.modified_at = NOW + col.created_by_email = None + col.modified_by_email = None + col_repo = FakeCollectionRepository() + col_repo.seed_readable("test-col", col) + + federated_entry = MagicMock(spec=SourceRegistryEntry) + federated_entry.short_name = "github" + federated_entry.federated_search = True + source_registry = FakeSourceRegistry() + source_registry.seed(federated_entry) + + svc = _build_service( + sc_repo=sc_repo, + collection_repo=col_repo, + source_registry=source_registry, + ) + obj_in = SourceConnectionUpdate(schedule={"cron": "0 * * * *"}) + + with pytest.raises(HTTPException) as exc_info: + await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) + assert exc_info.value.status_code == 400 + assert "federated search" in str(exc_info.value.detail) + + # --------------------------------------------------------------------------- # Credential updates -- table-driven # --------------------------------------------------------------------------- @@ -480,7 +544,7 @@ def test_cron_validation(case: CronCase): @pytest.mark.asyncio async def test_credential_update_triggers_unpause(): - """Successful direct auth credential update calls sync_state_machine.transition → ACTIVE.""" + """Successful direct auth credential update calls 
sync_service.resume.""" conn_id = uuid4() cred_id = uuid4() sync_id = uuid4() @@ -503,7 +567,7 @@ async def test_credential_update_triggers_unpause(): validation = FakeSourceValidationService() validation.seed_auth_result("github", _AuthPayload(token="secret")) - state_machine = AsyncMock() + sync_svc = FakeSyncService() svc = _build_service( sc_repo=sc_repo, @@ -511,24 +575,18 @@ async def test_credential_update_triggers_unpause(): cred_repo=cred_repo, source_validation=validation, credential_encryptor=FakeCredentialEncryptor(), - sync_state_machine=state_machine, + sync_service=sync_svc, ) obj_in = SourceConnectionUpdate(authentication={"credentials": {"token": "new_secret"}}) await svc.update(AsyncMock(), id=sc.id, obj_in=obj_in, ctx=_make_ctx()) - state_machine.transition.assert_called_once() - call_kwargs = state_machine.transition.call_args - from airweave.core.shared_models import SyncStatus - - assert call_kwargs.kwargs.get("target") == SyncStatus.ACTIVE or ( - len(call_kwargs.args) >= 2 and call_kwargs.args[1] == SyncStatus.ACTIVE - ) + assert any(c[0] == "resume" for c in sync_svc._calls) @pytest.mark.asyncio async def test_credential_update_unpause_failure_is_nonfatal(): - """If sync_state_machine.transition raises, the update still succeeds.""" + """If sync_service.resume raises, the update still succeeds.""" conn_id = uuid4() cred_id = uuid4() sync_id = uuid4() @@ -551,10 +609,8 @@ async def test_credential_update_unpause_failure_is_nonfatal(): validation = FakeSourceValidationService() validation.seed_auth_result("github", _AuthPayload(token="secret")) - state_machine = AsyncMock() - state_machine.transition.side_effect = InvalidSyncTransitionError( - SyncStatus.ACTIVE, SyncStatus.ACTIVE - ) + sync_svc = FakeSyncService() + sync_svc.set_error(ValueError("sync not active")) svc = _build_service( sc_repo=sc_repo, @@ -562,7 +618,7 @@ async def test_credential_update_unpause_failure_is_nonfatal(): cred_repo=cred_repo, source_validation=validation, 
credential_encryptor=FakeCredentialEncryptor(), - sync_state_machine=state_machine, + sync_service=sync_svc, ) obj_in = SourceConnectionUpdate(authentication={"credentials": {"token": "new_secret"}}) diff --git a/backend/airweave/domains/source_connections/update.py b/backend/airweave/domains/source_connections/update.py index 80af6484d..670220f2a 100644 --- a/backend/airweave/domains/source_connections/update.py +++ b/backend/airweave/domains/source_connections/update.py @@ -11,7 +11,6 @@ from airweave.api.context import ApiContext from airweave.core.exceptions import NotFoundException from airweave.core.protocols.encryption import CredentialEncryptor -from airweave.core.shared_models import SyncStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.collections.protocols import CollectionRepositoryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -23,19 +22,17 @@ ) from airweave.domains.sources.exceptions import SourceNotFoundError from airweave.domains.sources.protocols import ( + SourceRegistryProtocol, SourceServiceProtocol, SourceValidationServiceProtocol, ) -from airweave.domains.syncs.protocols import ( - SyncRecordServiceProtocol, - SyncRepositoryProtocol, - SyncStateMachineProtocol, -) +from airweave.domains.syncs.protocols import SyncRepositoryProtocol, SyncServiceProtocol from airweave.domains.syncs.types import InvalidSyncTransitionError, OptimisticLockError from airweave.domains.temporal.protocols import TemporalScheduleServiceProtocol from airweave.models.source_connection import SourceConnection from airweave.schemas.source_connection import ( AuthenticationMethod, + ScheduleConfig, SourceConnectionUpdate, ) from airweave.schemas.source_connection import ( @@ -59,13 +56,13 @@ def __init__( connection_repo: ConnectionRepositoryProtocol, cred_repo: IntegrationCredentialRepositoryProtocol, sync_repo: SyncRepositoryProtocol, - sync_record_service: SyncRecordServiceProtocol, + sync_service: 
SyncServiceProtocol, source_service: SourceServiceProtocol, + source_registry: SourceRegistryProtocol, source_validation: SourceValidationServiceProtocol, credential_encryptor: CredentialEncryptor, response_builder: ResponseBuilderProtocol, temporal_schedule_service: TemporalScheduleServiceProtocol, - sync_state_machine: SyncStateMachineProtocol, ) -> None: """Initialize with repositories and collaborator services.""" self._sc_repo = sc_repo @@ -73,13 +70,13 @@ def __init__( self._connection_repo = connection_repo self._cred_repo = cred_repo self._sync_repo = sync_repo - self._sync_record_service = sync_record_service + self._sync_service = sync_service self._source_service = source_service + self._source_registry = source_registry self._source_validation = source_validation self._credential_encryptor = credential_encryptor self._response_builder = response_builder self._temporal_schedule_service = temporal_schedule_service - self._sync_state_machine = sync_state_machine async def update( self, @@ -134,10 +131,9 @@ async def update( if source_conn.sync_id: try: - await self._sync_state_machine.transition( - sync_id=source_conn.sync_id, - target=SyncStatus.ACTIVE, - ctx=ctx, + await self._sync_service.resume( + source_conn.sync_id, + ctx, reason="Credential update completed", ) except (InvalidSyncTransitionError, OptimisticLockError, ValueError): @@ -203,62 +199,55 @@ async def _handle_schedule_update( uow, ) elif new_cron: - # No sync exists but we're adding a schedule - create a new sync - # Get the source to validate schedule - source = await self._get_and_validate_source(source_conn.short_name, ctx) - self._validate_cron_schedule_for_source(new_cron, source, ctx) - - # Check if connection_id exists (might be None for OAuth flows) if not source_conn.connection_id: ctx.logger.warning( f"Cannot create schedule for SC {source_conn.id} without connection_id" ) - # Skip schedule creation for connections without connection_id del update_data["schedule"] return - # Get 
the collection - collection = await self._collection_repo.get_by_readable_id( + collection_orm = await self._collection_repo.get_by_readable_id( uow.session, readable_id=source_conn.readable_collection_id, ctx=ctx ) - if not collection: + if not collection_orm: raise NotFoundException("Collection not found") + collection = schemas.CollectionRecord.model_validate( + collection_orm, from_attributes=True + ) - # Resolve destination IDs - dest_ids = await self._sync_record_service.resolve_destination_ids(uow.session, ctx) + source_entry = self._source_registry.get(source_conn.short_name) + if source_entry.federated_search: + raise HTTPException( + status_code=400, + detail=f"Source '{source_conn.short_name}' is a federated search source " + "and does not support scheduled syncs.", + ) - # Create a new sync with the schedule - sync, _ = await self._sync_record_service.create_sync( + dest_ids = await self._sync_service.resolve_destination_ids(uow.session, ctx) + + sync_result = await self._sync_service.create( uow.session, - name=f"Sync for {source_conn.name}", + name=source_conn.name, source_connection_id=source_conn.connection_id, destination_connection_ids=dest_ids, - cron_schedule=new_cron, + collection_id=collection.id, + collection_readable_id=collection.readable_id, + source_entry=source_entry, + schedule_config=ScheduleConfig(cron=new_cron), run_immediately=False, ctx=ctx, uow=uow, ) - # Apply the sync_id update to the source connection now - # so that temporal_schedule_service can find it source_conn = await self._sc_repo.update( uow.session, db_obj=source_conn, - obj_in={"sync_id": sync.id}, + obj_in={"sync_id": sync_result.sync_id}, ctx=ctx, uow=uow, ) await uow.session.flush() - # Create the Temporal schedule - await self._temporal_schedule_service.create_or_update_schedule( - sync_id=sync.id, - cron_schedule=new_cron, - db=uow.session, - ctx=ctx, - uow=uow, - ) - if "schedule" in update_data: del update_data["schedule"] diff --git 
a/backend/airweave/domains/sources/exceptions/classifier.py b/backend/airweave/domains/sources/exceptions/classifier.py index 5641f512a..5b555219c 100644 --- a/backend/airweave/domains/sources/exceptions/classifier.py +++ b/backend/airweave/domains/sources/exceptions/classifier.py @@ -66,6 +66,14 @@ def classify_error(exc: Exception) -> ErrorClassification: message=str(exc), ) + # Catch-all for remaining AuthProviderError subtypes (e.g. MissingFieldsError, + # ConfigError) — these are auth provider issues the user needs to address. + if isinstance(exc, AuthProviderError): + return ErrorClassification( + category=SourceConnectionErrorCategory.AUTH_PROVIDER_CREDENTIALS_INVALID, + message=str(exc), + ) + # --- Legacy SourceTokenRefreshError --- if isinstance(exc, SourceTokenRefreshError): return ErrorClassification( diff --git a/backend/airweave/domains/sources/lifecycle.py b/backend/airweave/domains/sources/lifecycle.py index 8c147f3de..cc568f47d 100644 --- a/backend/airweave/domains/sources/lifecycle.py +++ b/backend/airweave/domains/sources/lifecycle.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast from uuid import UUID if TYPE_CHECKING: @@ -22,10 +22,11 @@ from airweave.core.exceptions import NotFoundException from airweave.core.logging import ContextualLogger, LoggerConfigurator from airweave.core.shared_models import FeatureFlag -from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider._base import AUTH_PROVIDER_OPTIONAL_FIELDS, BaseAuthProvider from airweave.domains.auth_provider.exceptions import ( AuthProviderAccountNotFoundError, AuthProviderAuthError, + AuthProviderError, ) from airweave.domains.auth_provider.protocols import AuthProviderRegistryProtocol from airweave.domains.connections.protocols import ConnectionRepositoryProtocol @@ -40,11 +41,6 @@ SourceNotFoundError, 
SourceValidationError, ) -from airweave.domains.sources.token_providers.exceptions import ( - TokenCredentialsInvalidError, - TokenExpiredError, - TokenProviderAccountGoneError, -) from airweave.domains.sources.protocols import ( SourceLifecycleServiceProtocol, SourceRegistryProtocol, @@ -52,6 +48,11 @@ from airweave.domains.sources.rate_limiting.service import SourceRateLimiter from airweave.domains.sources.token_providers.auth_provider import AuthProviderTokenProvider from airweave.domains.sources.token_providers.credential import DirectCredentialProvider +from airweave.domains.sources.token_providers.exceptions import ( + TokenCredentialsInvalidError, + TokenExpiredError, + TokenProviderAccountGoneError, +) from airweave.domains.sources.token_providers.oauth import OAuthTokenProvider from airweave.domains.sources.token_providers.static import StaticTokenProvider from airweave.domains.sources.types import AuthConfig, SourceConnectionData, SourceRegistryEntry @@ -128,13 +129,19 @@ async def create( ) # 2. Get auth configuration (credentials + proxy setup) - auth_config = await self._get_auth_configuration( - db=db, - source_connection_data=source_connection_data, - ctx=ctx, - logger=logger, - access_token=access_token, - ) + try: + auth_config = await self._get_auth_configuration( + db=db, + source_connection_data=source_connection_data, + ctx=cast(ApiContext, ctx), + logger=logger, + access_token=access_token, + ) + except AuthProviderError as exc: + raise SourceValidationError( + short_name=source_connection_data.short_name, + reason=f"auth provider error: {exc}", + ) from exc # 3. 
Resolve auth provider token_provider = await self._resolve_token_provider( @@ -328,15 +335,12 @@ async def _get_auth_configuration( ) # Case 2: Auth provider connection - if ( - source_connection_data.readable_auth_provider_id - and source_connection_data.auth_provider_config - ): + if source_connection_data.readable_auth_provider_id: return await self._get_auth_provider_configuration( db=db, source_connection_data=source_connection_data, readable_auth_provider_id=source_connection_data.readable_auth_provider_id, - auth_provider_config=source_connection_data.auth_provider_config, + auth_provider_config=source_connection_data.auth_provider_config or {}, ctx=ctx, logger=logger, ) @@ -369,11 +373,13 @@ async def _get_auth_provider_configuration( logger=logger, ) - # Get runtime auth fields from the source registry (precomputed at startup) + # Get runtime auth fields from the source registry (precomputed at startup). + # Auth providers handle token refresh, so OAuth lifecycle fields + # (refresh_token, client_id, client_secret) are always optional. 
short_name = source_connection_data.short_name entry = self._source_registry.get(short_name) auth_fields_all = entry.runtime_auth_all_fields - auth_fields_optional = entry.runtime_auth_optional_fields + auth_fields_optional = entry.runtime_auth_optional_fields | AUTH_PROVIDER_OPTIONAL_FIELDS source_config_field_mappings = self._build_source_config_field_mappings( source_connection_data @@ -384,6 +390,7 @@ async def _get_auth_provider_configuration( source_auth_config_fields=auth_fields_all, optional_fields=auth_fields_optional, source_config_field_mappings=source_config_field_mappings or None, + source_connection_id=source_connection_data.source_connection_id, ) if auth_result.source_config: @@ -630,6 +637,17 @@ async def _resolve_token_provider( if access_token is not None: return StaticTokenProvider(access_token, source_short_name=short_name) + # Auth provider takes priority — ensures errors are classified correctly + # regardless of whether the source uses OAuth or direct auth. + if auth_provider_instance: + return AuthProviderTokenProvider( + auth_provider_instance=auth_provider_instance, + source_short_name=short_name, + source_registry=self._source_registry, + logger=logger, + source_connection_id=source_connection_data.source_connection_id, + ) + entry = self._source_registry.get(short_name) source_credentials = self._normalize_credentials(source_credentials, entry, logger) @@ -641,14 +659,6 @@ async def _resolve_token_provider( return DirectCredentialProvider(source_credentials, source_short_name=short_name) try: - if auth_provider_instance: - return AuthProviderTokenProvider( - auth_provider_instance=auth_provider_instance, - source_short_name=short_name, - source_registry=self._source_registry, - logger=logger, - ) - # Sources that support both OAuth and API key auth (e.g. calcom, coda) # may have structured credentials without access_token when using # API key mode — route those to DirectCredentialProvider. 
diff --git a/backend/airweave/domains/sources/tests/test_lifecycle.py b/backend/airweave/domains/sources/tests/test_lifecycle.py index 135ecb99d..c698e1b7b 100644 --- a/backend/airweave/domains/sources/tests/test_lifecycle.py +++ b/backend/airweave/domains/sources/tests/test_lifecycle.py @@ -402,7 +402,7 @@ class AuthConfigRoutingCase: AuthConfigRoutingCase(id="database-fallthrough", expected_route="database"), AuthConfigRoutingCase(id="auth-provider-id-but-no-config", readable_auth_provider_id="pd-1", - expected_route="database"), + expected_route="auth_provider"), ] diff --git a/backend/airweave/domains/sources/token_providers/auth_provider.py b/backend/airweave/domains/sources/token_providers/auth_provider.py index d16e7d298..9b833bc48 100644 --- a/backend/airweave/domains/sources/token_providers/auth_provider.py +++ b/backend/airweave/domains/sources/token_providers/auth_provider.py @@ -4,11 +4,12 @@ import time from typing import TYPE_CHECKING, Optional +from uuid import UUID from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential from airweave.core.logging import ContextualLogger -from airweave.domains.auth_provider._base import BaseAuthProvider +from airweave.domains.auth_provider._base import AUTH_PROVIDER_OPTIONAL_FIELDS, BaseAuthProvider from airweave.domains.auth_provider.exceptions import ( AuthProviderAccountNotFoundError, AuthProviderAuthError, @@ -50,12 +51,14 @@ def __init__( source_registry: SourceRegistryProtocol, *, logger: ContextualLogger, + source_connection_id: Optional[UUID] = None, ): """Initialize with an auth provider instance and source registry.""" self._provider = auth_provider_instance self._source_short_name = source_short_name self._source_registry = source_registry self._logger = logger + self._source_connection_id = source_connection_id self._cached_token: Optional[str] = None self._cached_at: float = 0.0 @@ -134,15 +137,28 @@ async def _fetch_token(self) -> str: provider_kind=self.provider_kind, 
) from e - if not isinstance(creds, dict) or "access_token" not in creds: + if not isinstance(creds, dict) or not creds: raise TokenProviderMissingCredsError( - f"No access_token in auth provider response for {self._source_short_name}", + f"Empty auth provider response for {self._source_short_name}", source_short_name=self._source_short_name, provider_kind=self.provider_kind, - missing_fields=["access_token"], + missing_fields=entry.runtime_auth_all_fields, ) - return creds["access_token"] + # Extract the primary credential value. Prefer access_token if present, + # otherwise use the first runtime auth field (e.g. personal_access_token, + # api_key). This supports both OAuth and non-OAuth sources. + for field in ["access_token"] + entry.runtime_auth_all_fields: + if field in creds: + return str(creds[field]) + + raise TokenProviderMissingCredsError( + f"No usable credential in auth provider response for {self._source_short_name}. " + f"Expected one of: {entry.runtime_auth_all_fields}", + source_short_name=self._source_short_name, + provider_kind=self.provider_kind, + missing_fields=entry.runtime_auth_all_fields, + ) @retry( retry=retry_if_exception_type((AuthProviderRateLimitError, AuthProviderServerError)), @@ -154,7 +170,8 @@ async def _call_provider_with_retry(self, entry) -> dict: return await self._provider.get_creds_for_source( source_short_name=self._source_short_name, source_auth_config_fields=entry.runtime_auth_all_fields, - optional_fields=entry.runtime_auth_optional_fields, + optional_fields=entry.runtime_auth_optional_fields | AUTH_PROVIDER_OPTIONAL_FIELDS, + source_connection_id=self._source_connection_id, ) async def get_token(self) -> str: diff --git a/backend/airweave/domains/sources/token_providers/oauth.py b/backend/airweave/domains/sources/token_providers/oauth.py index 844a6326a..6ce63cfc4 100644 --- a/backend/airweave/domains/sources/token_providers/oauth.py +++ b/backend/airweave/domains/sources/token_providers/oauth.py @@ -146,6 +146,31 @@ 
async def force_refresh(self) -> str: self._apply_refresh(result) return self._token + async def get_token_for_resource(self, resource_scope: str) -> Optional[str]: + """Exchange refresh token for an access token scoped to a different resource. + + Used by SharePoint Online to get a SP REST API token from a Graph-scoped + refresh token. The exchange does not persist the rotated refresh token. + + Returns None if refresh is not supported. + """ + if not self._can_refresh: + return None + + try: + async with get_db_context() as db: + result = await self._oauth2_service.exchange_token_for_scope( + db=db, + integration_short_name=self._source_short_name, + connection_id=self._connection_id, + ctx=self._ctx, + scope=resource_scope, + ) + return str(result) + except Exception as e: + self._logger.warning(f"Failed to exchange token for scope {resource_scope}: {e}") + return None + # ------------------------------------------------------------------ # Private # ------------------------------------------------------------------ diff --git a/backend/airweave/domains/sources/validation.py b/backend/airweave/domains/sources/validation.py index 195e9f567..25edf4ecb 100644 --- a/backend/airweave/domains/sources/validation.py +++ b/backend/airweave/domains/sources/validation.py @@ -37,14 +37,12 @@ def validate_config( """ entry = self._get_entry_or_404(short_name) - if not config_fields: - return {} - - payload = self._as_mapping(config_fields) - config_class = entry.config_ref if config_class is None: - return payload + # Source has no config schema — anything goes + return self._as_mapping(config_fields) if config_fields else {} + + payload = self._as_mapping(config_fields) if config_fields else {} self._enforce_feature_flags(short_name, payload, config_class, ctx) diff --git a/backend/airweave/domains/storage/file_service.py b/backend/airweave/domains/storage/file_service.py index 1d3897942..945ebf40d 100644 --- a/backend/airweave/domains/storage/file_service.py +++ 
b/backend/airweave/domains/storage/file_service.py @@ -61,7 +61,7 @@ def _ensure_base_dir(self) -> None: @staticmethod async def _resolve_headers(auth: SourceAuthProvider, url: str) -> dict: """Build auth headers. Pre-signed URLs skip the bearer token.""" - if "X-Amz-Algorithm" in url: + if "X-Amz-Algorithm" in url or "tempauth=" in url: return {} token = await auth.get_token() if hasattr(auth, "get_token") else None if not token: diff --git a/backend/airweave/domains/sync_pipeline/builders/destinations.py b/backend/airweave/domains/sync_pipeline/builders/destinations.py index f96bb9faa..2fa2896db 100644 --- a/backend/airweave/domains/sync_pipeline/builders/destinations.py +++ b/backend/airweave/domains/sync_pipeline/builders/destinations.py @@ -24,6 +24,7 @@ async def build_destinations( collection: schemas.CollectionRecord, logger: ContextualLogger, execution_config: Optional[SyncConfig] = None, + source_supports_acl: bool = False, ) -> List[BaseDestination]: """Build destinations.""" return await cls._create_destinations( @@ -31,6 +32,7 @@ async def build_destinations( collection=collection, logger=logger, execution_config=execution_config, + source_supports_acl=source_supports_acl, ) # ------------------------------------------------------------------------- @@ -44,6 +46,7 @@ async def _create_destinations( collection: schemas.CollectionRecord, logger: ContextualLogger, execution_config: Optional[SyncConfig] = None, + source_supports_acl: bool = False, ) -> List[BaseDestination]: """Create destination instances.""" destinations = [] @@ -59,6 +62,7 @@ async def _create_destinations( destination_connection_id=destination_connection_id, collection=collection, logger=logger, + source_supports_acl=source_supports_acl, ) if destination: destinations.append(destination) @@ -87,18 +91,20 @@ async def _create_single_destination( destination_connection_id: UUID, collection: schemas.CollectionRecord, logger: ContextualLogger, + source_supports_acl: bool = False, ) -> 
Optional[BaseDestination]: """Create a single destination instance.""" if destination_connection_id != NATIVE_VESPA_UUID: logger.warning(f"Unknown destination connection {destination_connection_id}, skipping") return None - return await cls._create_vespa(collection, logger) + return await cls._create_vespa(collection, logger, source_supports_acl=source_supports_acl) @classmethod async def _create_vespa( cls, collection: schemas.CollectionRecord, logger: ContextualLogger, + source_supports_acl: bool = False, ) -> BaseDestination: """Create native Vespa destination directly.""" logger.info("Using native Vespa destination (settings-based)") @@ -109,6 +115,7 @@ async def _create_vespa( organization_id=collection.organization_id, vector_size=None, logger=logger, + source_supports_acl=source_supports_acl, ) logger.info("Created native Vespa destination") return destination diff --git a/backend/airweave/domains/sync_pipeline/factory.py b/backend/airweave/domains/sync_pipeline/factory.py index ca0dff8c9..cdf225b02 100644 --- a/backend/airweave/domains/sync_pipeline/factory.py +++ b/backend/airweave/domains/sync_pipeline/factory.py @@ -184,12 +184,14 @@ async def create_orchestrator( execution_config=resolved_config, access_token=access_token, ) + source_entry = self._source_registry.get(sc.short_name) destinations = await self._build_destinations( db=db, sync=sync, collection=collection, ctx=ctx, execution_config=resolved_config, + source_supports_acl=source_entry.supports_access_control, ) entity_tracker = await self._build_entity_tracker( db=db, @@ -498,6 +500,7 @@ async def _build_destinations( collection: schemas.CollectionRecord, ctx: BaseContext, execution_config: SyncConfig, + source_supports_acl: bool = False, ) -> list: """Build destination instances for the sync.""" dest_logger = LoggerConfigurator.configure_logger( @@ -513,6 +516,7 @@ async def _build_destinations( collection=collection, logger=dest_logger, execution_config=execution_config, + 
source_supports_acl=source_supports_acl, ) # ------------------------------------------------------------------------- diff --git a/backend/airweave/domains/sync_pipeline/orchestrator.py b/backend/airweave/domains/sync_pipeline/orchestrator.py index 4a2f34de5..4796e4cf6 100644 --- a/backend/airweave/domains/sync_pipeline/orchestrator.py +++ b/backend/airweave/domains/sync_pipeline/orchestrator.py @@ -23,7 +23,7 @@ from airweave.domains.sync_pipeline.worker_pool import AsyncWorkerPool from airweave.domains.syncs.cursors.service import SyncCursorService from airweave.domains.syncs.jobs.protocols import SyncJobStateMachineProtocol -from airweave.domains.syncs.jobs.types import LifecycleData +from airweave.domains.syncs.jobs.types import InvalidTransitionError, LifecycleData from airweave.domains.syncs.protocols import SyncStateMachineProtocol from airweave.domains.temporal.metrics import worker_metrics from airweave.domains.usage.exceptions import ( @@ -714,21 +714,49 @@ async def _handle_cancellation(self) -> None: """Centralized cancellation handler - explicit and immediate.""" self.sync_context.logger.info("Handling cancellation...") - # 1. Cancel all pending tasks IMMEDIATELY + # Cancel all pending tasks immediately if self.worker_pool: await self.worker_pool.cancel_all() - # 2. Cancel stream to stop producer + # Cancel stream to stop producer await self.stream.cancel() - await self._state_machine.transition( - sync_job_id=self.sync_context.sync_job.id, - target=SyncJobStatus.CANCELLED, - ctx=self.sync_context, - lifecycle_data=self._lifecycle_data, - ) + # Transition through CANCELLING → CANCELLED. + # RUNNING → CANCELLING is required by the state machine. + # If still PENDING (cancellation before _start_sync completed), + # PENDING → CANCELLING is invalid, so fall through to direct CANCELLED. + # + # The workflow will also attempt these transitions via + # TransitionSyncJobActivity once the CancelledError propagates. 
+ # That redundancy is intentional — it guards against the activity + # being killed before the error reaches the workflow. + try: + await self._state_machine.transition( + sync_job_id=self.sync_context.sync_job.id, + target=SyncJobStatus.CANCELLING, + ctx=self.sync_context, + lifecycle_data=self._lifecycle_data, + ) + except InvalidTransitionError as exc: + self.sync_context.logger.debug( + "Skipped CANCELLING transition", + current_state=exc.current.value, + ) + + try: + await self._state_machine.transition( + sync_job_id=self.sync_context.sync_job.id, + target=SyncJobStatus.CANCELLED, + ctx=self.sync_context, + lifecycle_data=self._lifecycle_data, + ) + except InvalidTransitionError as exc: + self.sync_context.logger.warning( + "Skipped CANCELLED transition — job in unexpected terminal state", + current_state=exc.current.value, + ) - # 4. Track sync cancelled + # Track sync cancelled if not self.sync_context.sync_job.started_at: # This can happen if cancellation occurs during _start_sync before # the job status is updated with started_at diff --git a/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py b/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py index 2044b5b2d..2bf4a0a24 100644 --- a/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py +++ b/backend/airweave/domains/sync_pipeline/tests/test_orchestrator_coverage.py @@ -1,6 +1,5 @@ """Coverage tests for SyncOrchestrator — missing state_machine.transition lines.""" -import asyncio from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch from uuid import UUID @@ -12,6 +11,7 @@ from airweave.domains.sync_pipeline.contexts.sync import SyncContext from airweave.domains.sync_pipeline.orchestrator import SyncOrchestrator from airweave.domains.sync_pipeline.pipeline.entity_tracker import SyncStats +from airweave.domains.syncs.jobs.types import InvalidTransitionError MODULE = 
"airweave.domains.sync_pipeline.orchestrator" @@ -184,11 +184,37 @@ async def test_handle_sync_failure_calls_state_machine_transition(): @pytest.mark.unit -async def test_handle_cancellation_calls_state_machine_transition(): - """_handle_cancellation calls state_machine.transition with CANCELLED.""" +async def test_handle_cancellation_transitions_through_cancelling(): + """_handle_cancellation transitions RUNNING → CANCELLING → CANCELLED.""" orch, sm = _make_orchestrator() with patch(f"{MODULE}.business_events"): await orch._handle_cancellation() - assert any(c["target"] == SyncJobStatus.CANCELLED for c in sm.calls) + targets = [c["target"] for c in sm.calls] + assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] + + +@pytest.mark.unit +async def test_handle_cancellation_pending_falls_through_to_cancelled(): + """When CANCELLING raises InvalidTransitionError (PENDING), go directly to CANCELLED.""" + + class PendingStateMachine: + def __init__(self): + self.calls: list[dict] = [] + + async def transition(self, **kwargs): + self.calls.append(kwargs) + if kwargs["target"] == SyncJobStatus.CANCELLING: + raise InvalidTransitionError( + SyncJobStatus.PENDING, SyncJobStatus.CANCELLING + ) + return MagicMock(applied=True) + + orch, sm = _make_orchestrator(state_machine=PendingStateMachine()) + + with patch(f"{MODULE}.business_events"): + await orch._handle_cancellation() + + targets = [c["target"] for c in sm.calls] + assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] diff --git a/backend/airweave/domains/syncs/fakes/lifecycle_service.py b/backend/airweave/domains/syncs/fakes/lifecycle_service.py deleted file mode 100644 index 0f5be25f2..000000000 --- a/backend/airweave/domains/syncs/fakes/lifecycle_service.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Fake sync lifecycle service for testing.""" - -from typing import List, Optional -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave.api.context import 
ApiContext -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.types import SyncProvisionResult -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob - - -class FakeSyncLifecycleService: - """In-memory fake for SyncLifecycleServiceProtocol.""" - - def __init__(self) -> None: - """Initialize with empty state.""" - self._calls: list[tuple] = [] - self._provision_result: Optional[SyncProvisionResult] = None - self._run_result: Optional[SourceConnectionJob] = None - self._jobs: dict[UUID, List[SourceConnectionJob]] = {} - self._cancel_result: Optional[SourceConnectionJob] = None - self._should_raise: Optional[Exception] = None - - def set_provision_result(self, result: Optional[SyncProvisionResult]) -> None: - """Configure provision_sync() return value.""" - self._provision_result = result - - def set_run_result(self, result: SourceConnectionJob) -> None: - """Configure run() return value.""" - self._run_result = result - - def seed_jobs(self, sc_id: UUID, jobs: List[SourceConnectionJob]) -> None: - """Seed jobs returned by get_jobs.""" - self._jobs[sc_id] = jobs - - def set_cancel_result(self, result: SourceConnectionJob) -> None: - """Configure cancel_job() return value.""" - self._cancel_result = result - - def set_error(self, error: Exception) -> None: - """Make all subsequent calls raise this error.""" - self._should_raise = error - - async def teardown_syncs_for_collection( - self, - db: AsyncSession, - *, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - cancel_timeout_seconds: int = 15, - ) -> None: - """Record call — no-op fake.""" - self._calls.append( - ("teardown_syncs_for_collection", db, sync_ids, collection_id, organization_id, ctx) - ) - if self._should_raise: - raise self._should_raise - - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - 
destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Record call and return canned result.""" - self._calls.append(("provision_sync", name, source_connection_id, collection_id)) - if self._should_raise: - raise self._should_raise - return self._provision_result - - async def run( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Record call and return canned result.""" - self._calls.append(("run", db, id, ctx, force_full_sync)) - if self._should_raise: - raise self._should_raise - if self._run_result is None: - raise RuntimeError("FakeSyncLifecycleService.run_result not configured") - return self._run_result - - async def get_jobs( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - limit: int = 100, - ) -> List[SourceConnectionJob]: - """Record call and return seeded jobs.""" - self._calls.append(("get_jobs", db, id, ctx, limit)) - if self._should_raise: - raise self._should_raise - return self._jobs.get(id, [])[:limit] - - async def cancel_job( - self, - db: AsyncSession, - *, - source_connection_id: UUID, - job_id: UUID, - ctx: ApiContext, - ) -> SourceConnectionJob: - """Record call and return canned result.""" - self._calls.append(("cancel_job", db, source_connection_id, job_id, ctx)) - if self._should_raise: - raise self._should_raise - if self._cancel_result is None: - raise RuntimeError("FakeSyncLifecycleService.cancel_result not configured") - return self._cancel_result diff --git a/backend/airweave/domains/syncs/fakes/record_service.py b/backend/airweave/domains/syncs/fakes/record_service.py deleted file mode 100644 index f0d45fb95..000000000 --- a/backend/airweave/domains/syncs/fakes/record_service.py +++ /dev/null @@ -1,91 +0,0 @@ 
-"""Fake sync record service for testing.""" - -from typing import List, Optional, Tuple -from uuid import UUID - -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.db.unit_of_work import UnitOfWork - - -class FakeSyncRecordService: - """In-memory fake for SyncRecordServiceProtocol.""" - - def __init__(self) -> None: - """Initialize with empty state.""" - self._calls: list[tuple] = [] - self._create_result: Optional[Tuple[schemas.Sync, Optional[schemas.SyncJob]]] = None - self._trigger_result: Optional[Tuple[schemas.Sync, schemas.SyncJob]] = None - self._resolve_dest_ids: Optional[List[UUID]] = None - self._should_raise: Optional[Exception] = None - - def set_create_result( - self, sync: schemas.Sync, sync_job: Optional[schemas.SyncJob] = None - ) -> None: - """Configure create_sync return value.""" - self._create_result = (sync, sync_job) - - def set_trigger_result(self, sync: schemas.Sync, sync_job: schemas.SyncJob) -> None: - """Configure trigger_sync_run return value.""" - self._trigger_result = (sync, sync_job) - - def set_resolve_dest_ids(self, ids: List[UUID]) -> None: - """Configure resolve_destination_ids return value.""" - self._resolve_dest_ids = ids - - def set_error(self, error: Exception) -> None: - """Make all subsequent calls raise this error.""" - self._should_raise = error - - async def resolve_destination_ids( - self, - db: AsyncSession, - ctx: ApiContext, - ) -> List[UUID]: - """Record call and return canned result.""" - self._calls.append(("resolve_destination_ids",)) - if self._should_raise: - raise self._should_raise - if self._resolve_dest_ids is None: - from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID - - return [NATIVE_VESPA_UUID] - return self._resolve_dest_ids - - async def create_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - cron_schedule: 
Optional[str], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Record call and return canned result.""" - self._calls.append( - ("create_sync", name, source_connection_id, cron_schedule, run_immediately) - ) - if self._should_raise: - raise self._should_raise - if self._create_result is None: - raise RuntimeError("FakeSyncRecordService.create_result not configured") - return self._create_result - - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Record call and return canned result.""" - self._calls.append(("trigger_sync_run", db, sync_id, ctx)) - if self._should_raise: - raise self._should_raise - if self._trigger_result is None: - raise RuntimeError("FakeSyncRecordService.trigger_result not configured") - return self._trigger_result diff --git a/backend/airweave/domains/syncs/fakes/service.py b/backend/airweave/domains/syncs/fakes/service.py index d31050fb6..eade30d2e 100644 --- a/backend/airweave/domains/syncs/fakes/service.py +++ b/backend/airweave/domains/syncs/fakes/service.py @@ -1,17 +1,181 @@ -"""Fake sync service for testing.""" +"""Fake sync service for testing — matches unified SyncServiceProtocol.""" -from typing import Optional +from typing import Dict, List, Optional, Tuple +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas from airweave.api.context import ApiContext +from airweave.core.context import BaseContext +from airweave.core.shared_models import SyncStatus +from airweave.db.unit_of_work import UnitOfWork +from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.sync_pipeline.config import SyncConfig +from airweave.domains.syncs.protocols import SyncServiceProtocol +from airweave.domains.syncs.types import SyncProvisionResult, SyncTransitionResult +from airweave.schemas.source_connection import 
ScheduleConfig -class FakeSyncService: - """In-memory fake for SyncServiceProtocol.""" +class FakeSyncService(SyncServiceProtocol): + """In-memory fake for the unified SyncServiceProtocol.""" def __init__(self) -> None: self._calls: list[tuple] = [] + self._create_result: Optional[SyncProvisionResult] = None + self._get_result: Optional[schemas.Sync] = None + self._trigger_run_result: Optional[Tuple[schemas.Sync, schemas.SyncJob]] = None + self._jobs: Dict[UUID, List[schemas.SyncJob]] = {} + self._cancel_result: Optional[schemas.SyncJob] = None + self._resolve_dest_ids: Optional[List[UUID]] = None + self._run_result: Optional[schemas.Sync] = None + self._should_raise: Optional[Exception] = None + + # -- Configuration helpers -- + + def set_create_result(self, result: SyncProvisionResult) -> None: + self._create_result = result + + def set_get_result(self, result: schemas.Sync) -> None: + self._get_result = result + + def set_trigger_run_result(self, sync: schemas.Sync, job: schemas.SyncJob) -> None: + self._trigger_run_result = (sync, job) + + def seed_jobs(self, sync_id: UUID, jobs: List[schemas.SyncJob]) -> None: + self._jobs[sync_id] = jobs + + def set_cancel_result(self, result: schemas.SyncJob) -> None: + self._cancel_result = result + + def set_resolve_dest_ids(self, ids: List[UUID]) -> None: + self._resolve_dest_ids = ids + + def set_run_result(self, result: schemas.Sync) -> None: + self._run_result = result + + def set_error(self, error: Exception) -> None: + self._should_raise = error + + # -- Lifecycle -- + + async def create( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> SyncProvisionResult: + self._calls.append(("create", name, source_connection_id, collection_id)) + if 
self._should_raise: + raise self._should_raise + if self._create_result is None: + raise RuntimeError("FakeSyncService.create_result not configured") + return self._create_result + + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + self._calls.append(("get", sync_id)) + if self._should_raise: + raise self._should_raise + if self._get_result is None: + raise ValueError(f"Sync {sync_id} not found") + return self._get_result + + async def pause( + self, sync_id: UUID, ctx: BaseContext, *, reason: str = "" + ) -> SyncTransitionResult: + self._calls.append(("pause", sync_id, reason)) + if self._should_raise: + raise self._should_raise + return SyncTransitionResult( + applied=True, previous=SyncStatus.ACTIVE, current=SyncStatus.PAUSED + ) + + async def resume( + self, sync_id: UUID, ctx: BaseContext, *, reason: str = "" + ) -> SyncTransitionResult: + self._calls.append(("resume", sync_id, reason)) + if self._should_raise: + raise self._should_raise + return SyncTransitionResult( + applied=True, previous=SyncStatus.PAUSED, current=SyncStatus.ACTIVE + ) + + async def delete( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + cancel_timeout_seconds: int = 15, + ) -> None: + self._calls.append(("delete", sync_id, collection_id, organization_id)) + if self._should_raise: + raise self._should_raise + + # -- Jobs -- + + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + self._calls.append(("resolve_destination_ids",)) + if self._should_raise: + raise self._should_raise + if self._resolve_dest_ids is not None: + return self._resolve_dest_ids + from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID + + return [NATIVE_VESPA_UUID] + + async def trigger_run( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: ApiContext, + force_full_sync: 
bool = False, + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + self._calls.append(("trigger_run", sync_id)) + if self._should_raise: + raise self._should_raise + if self._trigger_run_result is None: + raise RuntimeError("FakeSyncService.trigger_run_result not configured") + return self._trigger_run_result + + async def get_jobs( + self, db: AsyncSession, *, sync_id: UUID, ctx: ApiContext, limit: int = 100 + ) -> List[schemas.SyncJob]: + self._calls.append(("get_jobs", sync_id, limit)) + if self._should_raise: + raise self._should_raise + return self._jobs.get(sync_id, [])[:limit] + + async def cancel_job( + self, db: AsyncSession, *, job_id: UUID, ctx: ApiContext + ) -> schemas.SyncJob: + self._calls.append(("cancel_job", job_id)) + if self._should_raise: + raise self._should_raise + if self._cancel_result is None: + raise RuntimeError("FakeSyncService.cancel_result not configured") + return self._cancel_result + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + self._calls.append(("validate_force_full_sync", sync_id)) + + # -- Execution -- async def run( self, @@ -22,6 +186,9 @@ async def run( ctx: ApiContext, force_full_sync: bool = False, execution_config: Optional[SyncConfig] = None, + access_token: Optional[str] = None, ) -> schemas.Sync: self._calls.append(("run", sync, sync_job)) - return sync + if self._should_raise: + raise self._should_raise + return self._run_result or sync diff --git a/backend/airweave/domains/syncs/jobs/fakes/repository.py b/backend/airweave/domains/syncs/jobs/fakes/repository.py index 16867e26f..7419acea6 100644 --- a/backend/airweave/domains/syncs/jobs/fakes/repository.py +++ b/backend/airweave/domains/syncs/jobs/fakes/repository.py @@ -54,11 +54,19 @@ async def get_active_for_sync( return [j for j in jobs if j.status in ("PENDING", "RUNNING", "CANCELLING")] async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + self, + db: AsyncSession, + 
sync_id: UUID, + ctx: ApiContext, + limit: Optional[int] = None, ) -> List[SyncJob]: - """Return all seeded jobs for the sync.""" - self._calls.append(("get_all_by_sync_id", db, sync_id, ctx)) - return self._by_sync.get(sync_id, []) + """Return all seeded jobs for the sync (newest first; optional limit).""" + self._calls.append(("get_all_by_sync_id", db, sync_id, ctx, limit)) + jobs = list(self._by_sync.get(sync_id, [])) + jobs.sort(key=lambda j: j.created_at, reverse=True) + if limit is not None: + jobs = jobs[:limit] + return jobs async def create( self, diff --git a/backend/airweave/domains/syncs/jobs/protocols.py b/backend/airweave/domains/syncs/jobs/protocols.py index 59eb906fe..3143426f7 100644 --- a/backend/airweave/domains/syncs/jobs/protocols.py +++ b/backend/airweave/domains/syncs/jobs/protocols.py @@ -35,7 +35,11 @@ async def get_active_for_sync( ... async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: BaseContext + self, + db: AsyncSession, + sync_id: UUID, + ctx: BaseContext, + limit: Optional[int] = None, ) -> List[SyncJob]: """Get all jobs for a specific sync.""" ... 
diff --git a/backend/airweave/domains/syncs/jobs/repository.py b/backend/airweave/domains/syncs/jobs/repository.py index 9a504b4ee..66d0000aa 100644 --- a/backend/airweave/domains/syncs/jobs/repository.py +++ b/backend/airweave/domains/syncs/jobs/repository.py @@ -8,6 +8,7 @@ from airweave import crud from airweave.api.context import ApiContext +from airweave.core.context import BaseContext from airweave.core.shared_models import SyncJobStatus from airweave.db.unit_of_work import UnitOfWork from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol @@ -41,10 +42,14 @@ async def get_active_for_sync( ) async def get_all_by_sync_id( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + self, + db: AsyncSession, + sync_id: UUID, + ctx: BaseContext, + limit: Optional[int] = None, ) -> List[SyncJob]: """Get all jobs for a specific sync.""" - return await crud.sync_job.get_all_by_sync_id(db, sync_id=sync_id) + return await crud.sync_job.get_all_by_sync_id(db, sync_id=sync_id, limit=limit) async def create( self, diff --git a/backend/airweave/domains/syncs/lifecycle_service.py b/backend/airweave/domains/syncs/lifecycle_service.py deleted file mode 100644 index e2912a462..000000000 --- a/backend/airweave/domains/syncs/lifecycle_service.py +++ /dev/null @@ -1,466 +0,0 @@ -"""Sync lifecycle service: provision, run, get_jobs, cancel_job, teardown.""" - -import asyncio -import re -from datetime import datetime, timezone -from typing import List, Optional -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.events.sync import SyncLifecycleEvent -from airweave.core.protocols.event_bus import EventBus -from airweave.core.shared_models import SyncJobStatus -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.collections.protocols import CollectionRepositoryProtocol -from 
airweave.domains.connections.protocols import ConnectionRepositoryProtocol -from airweave.domains.source_connections.protocols import ( - ResponseBuilderProtocol, - SourceConnectionRepositoryProtocol, -) -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.jobs.protocols import ( - SyncJobRepositoryProtocol, - SyncJobStateMachineProtocol, -) -from airweave.domains.syncs.protocols import ( - SyncCursorRepositoryProtocol, - SyncLifecycleServiceProtocol, - SyncRecordServiceProtocol, -) -from airweave.domains.syncs.types import ( - CONTINUOUS_SOURCE_DEFAULT_CRON, - DAILY_CRON_TEMPLATE, - SyncProvisionResult, -) -from airweave.domains.temporal.protocols import ( - TemporalScheduleServiceProtocol, - TemporalWorkflowServiceProtocol, -) -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob - -_SUB_HOURLY_PATTERN = re.compile(r"^\*/([1-5]?[0-9]) \* \* \* \*$") - - -class SyncLifecycleService(SyncLifecycleServiceProtocol): - """API-facing facade for sync lifecycle: provision, run, get_jobs, cancel_job.""" - - def __init__( - self, - sc_repo: SourceConnectionRepositoryProtocol, - collection_repo: CollectionRepositoryProtocol, - connection_repo: ConnectionRepositoryProtocol, - sync_cursor_repo: SyncCursorRepositoryProtocol, - sync_service: SyncRecordServiceProtocol, - state_machine: SyncJobStateMachineProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - temporal_workflow_service: TemporalWorkflowServiceProtocol, - temporal_schedule_service: TemporalScheduleServiceProtocol, - response_builder: ResponseBuilderProtocol, - event_bus: EventBus, - ) -> None: - """Initialize with all injected dependencies.""" - self._sc_repo = sc_repo - self._collection_repo = collection_repo - self._connection_repo = connection_repo - self._sync_cursor_repo = sync_cursor_repo - self._sync_service = sync_service - self._state_machine = state_machine - self._sync_job_repo = sync_job_repo - self._temporal_workflow_service = 
temporal_workflow_service - self._temporal_schedule_service = temporal_schedule_service - self._response_builder = response_builder - self._event_bus = event_bus - - # ------------------------------------------------------------------ - # Public API (protocol surface) - # ------------------------------------------------------------------ - - async def teardown_syncs_for_collection( - self, - db: AsyncSession, - *, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - cancel_timeout_seconds: int = 15, - ) -> None: - """Cancel running workflows and schedule async cleanup for a collection's syncs. - - 1. Cancels PENDING/RUNNING workflows via Temporal. - 2. Polls until terminal state (up to cancel_timeout_seconds). - 3. Schedules async cleanup workflow for Vespa/ARF/schedules. - """ - syncs_to_wait = await self._cancel_active_syncs(db, sync_ids, ctx) - await self._wait_for_terminal(db, syncs_to_wait, cancel_timeout_seconds, ctx) - await self._schedule_collection_cleanup(sync_ids, collection_id, organization_id, ctx) - - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Create sync + job + Temporal schedule atomically. - - Returns None for federated search sources (no sync needed) - or when there is neither a schedule nor an immediate run. 
- """ - if source_entry.federated_search: - ctx.logger.info(f"Skipping sync for federated source '{source_entry.short_name}'") - return None - - cron = self._resolve_cron(schedule_config, source_entry, ctx) - - if not cron and not run_immediately: - ctx.logger.info("No cron schedule and run_immediately=False, skipping sync creation") - return None - - if cron: - self._validate_cron_for_source(cron, source_entry) - - sync_schema, sync_job_schema = await self._sync_service.create_sync( - uow.session, - name=f"Sync for {name}", - source_connection_id=source_connection_id, - destination_connection_ids=destination_connection_ids, - cron_schedule=cron, - run_immediately=run_immediately, - ctx=ctx, - uow=uow, - ) - - if cron: - await self._temporal_schedule_service.create_or_update_schedule( - sync_id=sync_schema.id, - cron_schedule=cron, - db=uow.session, - ctx=ctx, - uow=uow, - collection_readable_id=collection_readable_id, - connection_id=source_connection_id, - ) - - return SyncProvisionResult( - sync_id=sync_schema.id, - sync=sync_schema, - sync_job=sync_job_schema, - cron_schedule=cron, - ) - - async def run( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Trigger a sync run for a source connection. - - Args: - db: Database session. - id: Source connection ID. - ctx: API context. - force_full_sync: Only valid for continuous syncs. 
- """ - source_conn = await self._sc_repo.get(db, id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - raise HTTPException(status_code=400, detail="Source connection has no associated sync") - - sc_id = source_conn.id - sc_sync_id = source_conn.sync_id - - if force_full_sync: - await self._validate_force_full_sync(db, sc_sync_id, ctx) - - collection = await self._collection_repo.get_by_readable_id( - db, source_conn.readable_collection_id, ctx - ) - collection_schema = schemas.CollectionRecord.model_validate( - collection, from_attributes=True - ) - - connection_schema = await self._resolve_connection(db, source_conn, ctx) - - sync, sync_job = await self._sync_service.trigger_sync_run(db, sync_id=sc_sync_id, ctx=ctx) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - await self._event_bus.publish( - SyncLifecycleEvent.pending( - organization_id=ctx.organization.id, - source_connection_id=sc_id, - sync_job_id=sync_job_schema.id, - sync_id=sc_sync_id, - collection_id=collection_schema.id, - source_type=connection_schema.short_name, - collection_name=collection_schema.name, - collection_readable_id=collection_schema.readable_id, - ) - ) - - await self._temporal_workflow_service.run_source_connection_workflow( - sync=sync, - sync_job=sync_job, - collection=collection_schema, - connection=connection_schema, - ctx=ctx, - force_full_sync=force_full_sync, - ) - - return sync_job_schema.to_source_connection_job(sc_id) - - async def get_jobs( - self, - db: AsyncSession, - *, - id: UUID, - ctx: ApiContext, - limit: int = 100, - ) -> List[SourceConnectionJob]: - """Get sync jobs for a source connection.""" - source_conn = await self._sc_repo.get(db, id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - return [] - - jobs = await self._sync_job_repo.get_all_by_sync_id(db, 
source_conn.sync_id, ctx) - return [self._response_builder.map_sync_job(j, source_conn.id) for j in jobs] - - async def cancel_job( - self, - db: AsyncSession, - *, - source_connection_id: UUID, - job_id: UUID, - ctx: ApiContext, - ) -> SourceConnectionJob: - """Cancel a running sync job. - - Sets CANCELLING, sends cancel to Temporal, and handles - edge cases (workflow not found, Temporal failure). - """ - source_conn = await self._sc_repo.get(db, source_connection_id, ctx) - if not source_conn: - raise HTTPException(status_code=404, detail="Source connection not found") - if not source_conn.sync_id: - raise HTTPException(status_code=400, detail="Source connection has no associated sync") - - sync_job = await self._sync_job_repo.get(db, job_id, ctx) - if not sync_job: - raise HTTPException(status_code=404, detail="Sync job not found") - if sync_job.sync_id != source_conn.sync_id: - raise HTTPException( - status_code=400, - detail="Sync job does not belong to this source connection", - ) - if sync_job.status not in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): - raise HTTPException( - status_code=400, - detail=f"Cannot cancel job in {sync_job.status} state", - ) - - await self._state_machine.transition( - sync_job_id=job_id, target=SyncJobStatus.CANCELLING, ctx=ctx - ) - - cancel_result = await self._temporal_workflow_service.cancel_sync_job_workflow( - str(job_id), ctx - ) - - if not cancel_result["success"]: - raise HTTPException( - status_code=502, detail="Failed to request cancellation from Temporal" - ) - - if not cancel_result["workflow_found"]: - ctx.logger.info(f"Workflow not found for job {job_id} - marking CANCELLED directly") - await self._state_machine.transition( - sync_job_id=job_id, - target=SyncJobStatus.CANCELLED, - ctx=ctx, - error="Workflow not found in Temporal - may have already completed", - ) - - await db.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - return 
sync_job_schema.to_source_connection_job(source_connection_id) - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - async def _cancel_active_syncs( - self, - db: AsyncSession, - sync_ids: List[UUID], - ctx: ApiContext, - ) -> List[UUID]: - """Cancel PENDING/RUNNING jobs and return IDs that need waiting.""" - non_terminal = {SyncJobStatus.PENDING, SyncJobStatus.RUNNING, SyncJobStatus.CANCELLING} - syncs_to_wait: List[UUID] = [] - for sync_id in sync_ids: - latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) - if not latest_job or latest_job.status not in non_terminal: - continue - if latest_job.status in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): - try: - await self._temporal_workflow_service.cancel_sync_job_workflow( - str(latest_job.id), ctx - ) - ctx.logger.info(f"Cancelled job {latest_job.id} before deletion") - except Exception as e: - ctx.logger.warning(f"Failed to cancel job {latest_job.id}: {e}") - syncs_to_wait.append(sync_id) - return syncs_to_wait - - async def _wait_for_terminal( - self, - db: AsyncSession, - syncs_to_wait: List[UUID], - timeout_seconds: int, - ctx: ApiContext, - ) -> None: - """Poll until all syncs reach a terminal state or timeout.""" - if not syncs_to_wait: - return - terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} - elapsed = 0.0 - remaining = list(syncs_to_wait) - while elapsed < timeout_seconds and remaining: - await asyncio.sleep(1.0) - elapsed += 1.0 - db.expire_all() - still_waiting = [] - for sid in remaining: - job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sid) - if job and job.status not in terminal: - still_waiting.append(sid) - remaining = still_waiting - if remaining: - ctx.logger.warning( - f"{len(remaining)} sync(s) did not reach terminal state " - f"within {timeout_seconds}s -- proceeding with deletion anyway" - ) - - 
async def _schedule_collection_cleanup( - self, - sync_ids: List[UUID], - collection_id: UUID, - organization_id: UUID, - ctx: ApiContext, - ) -> None: - """Schedule a Temporal workflow for async Vespa/ARF cleanup.""" - if not sync_ids: - return - try: - await self._temporal_workflow_service.start_cleanup_sync_data_workflow( - sync_ids=[str(sid) for sid in sync_ids], - collection_id=str(collection_id), - organization_id=str(organization_id), - ctx=ctx, - ) - except Exception as e: - ctx.logger.error( - f"Failed to schedule async cleanup for collection {collection_id}: {e}. " - f"Data may be orphaned in Vespa/ARF." - ) - - async def _validate_force_full_sync( - self, db: AsyncSession, sync_id: UUID, ctx: ApiContext - ) -> None: - """Log force_full_sync intent. No-op if no cursor (already a full sync).""" - cursor = await self._sync_cursor_repo.get_by_sync_id(db, sync_id, ctx) - if not cursor or not cursor.cursor_data: - ctx.logger.info( - f"force_full_sync requested but no cursor data exists for sync {sync_id}. " - "This sync will perform a full sync by default." - ) - return - ctx.logger.info( - f"Force full sync requested for continuous sync {sync_id}. " - "Will ignore cursor data and perform full sync with orphaned entity cleanup." - ) - - async def _resolve_connection( - self, db: AsyncSession, source_conn, ctx: ApiContext - ) -> schemas.Connection: - """Resolve the Connection (not SourceConnection!) 
for a source connection.""" - if not source_conn.connection_id: - raise ValueError(f"Source connection {source_conn.id} has no connection_id") - conn = await self._connection_repo.get(db, source_conn.connection_id, ctx) - if not conn: - raise ValueError(f"Connection {source_conn.connection_id} not found") - return schemas.Connection.model_validate(conn, from_attributes=True) - - def _resolve_cron( - self, - schedule_config: Optional[ScheduleConfig], - source_entry: SourceRegistryEntry, - ctx: ApiContext, - ) -> Optional[str]: - """Resolve cron schedule from config or source defaults. - - When schedule_config is provided: - - cron is a string → use it - - cron is None → caller explicitly wants no schedule - When schedule_config is None → apply source-type defaults. - """ - if schedule_config is not None: - if schedule_config.cron is not None: - return schedule_config.cron - ctx.logger.info("Schedule cron explicitly null, no schedule") - return None - - if source_entry.supports_continuous: - ctx.logger.info("Continuous source, defaulting to 5-minute schedule") - return CONTINUOUS_SOURCE_DEFAULT_CRON - - now_utc = datetime.now(timezone.utc) - cron = DAILY_CRON_TEMPLATE.format(minute=now_utc.minute, hour=now_utc.hour) - ctx.logger.info(f"Defaulting to daily at {now_utc.hour:02d}:{now_utc.minute:02d} UTC") - return cron - - def _validate_cron_for_source( - self, - cron: str, - source_entry: SourceRegistryEntry, - ) -> None: - """Reject sub-hourly schedules for non-continuous sources.""" - if source_entry.supports_continuous: - return - - if cron == "* * * * *": - raise HTTPException( - status_code=400, - detail=( - f"Source '{source_entry.short_name}' does not support " - f"continuous syncs. Minimum interval is 1 hour." - ), - ) - - match = _SUB_HOURLY_PATTERN.match(cron) - if match and int(match.group(1)) < 60: - raise HTTPException( - status_code=400, - detail=( - f"Source '{source_entry.short_name}' does not support " - f"continuous syncs. 
Minimum interval is 1 hour." - ), - ) diff --git a/backend/airweave/domains/syncs/protocols.py b/backend/airweave/domains/syncs/protocols.py index eeda5e2bc..7d17b55fe 100644 --- a/backend/airweave/domains/syncs/protocols.py +++ b/backend/airweave/domains/syncs/protocols.py @@ -17,7 +17,7 @@ from airweave.domains.syncs.types import SyncProvisionResult, SyncTransitionResult from airweave.models.sync import Sync from airweave.models.sync_cursor import SyncCursor -from airweave.schemas.source_connection import ScheduleConfig, SourceConnectionJob +from airweave.schemas.source_connection import ScheduleConfig from airweave.schemas.sync import SyncCreate, SyncUpdate @@ -79,147 +79,148 @@ async def get_by_sync_id( ... -class SyncRecordServiceProtocol(Protocol): - """Sync record management: create syncs and trigger runs.""" +class SyncStateMachineProtocol(Protocol): + """Validated, idempotent sync status transitions with schedule side effects.""" - async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: - """Resolve destination connection IDs based on feature flags.""" + async def transition( + self, + sync_id: UUID, + target: SyncStatus, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Execute a validated, idempotent sync status transition. + + Side effects (schedule pause/unpause) run after the DB commit. + """ ... - async def create_sync( + +class SyncServiceProtocol(Protocol): + """Unified sync service — the public interface for the syncs domain. + + Provides lifecycle (create, get, pause, resume, delete), job management + (trigger_run, get_jobs, cancel_job), and execution (run) operations. + All methods speak the sync domain language; no source_connection types + cross this boundary. 
+ """ + + # -- Lifecycle -- + + async def create( self, db: AsyncSession, *, name: str, source_connection_id: UUID, destination_connection_ids: List[UUID], - cron_schedule: Optional[str], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], run_immediately: bool, ctx: ApiContext, uow: UnitOfWork, - ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Create a Sync record and optionally a PENDING SyncJob. - - All writes happen inside the caller's UoW (no commit). - """ + ) -> SyncProvisionResult: + """Create sync + optional job + Temporal schedule atomically.""" ... - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Trigger a manual sync run. - - Returns (sync_schema, sync_job_schema). - Raises HTTPException if a job is already active. - """ + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + """Get a sync by ID.""" ... - -class SyncStateMachineProtocol(Protocol): - """Validated, idempotent sync status transitions with schedule side effects.""" - - async def transition( + async def pause( self, sync_id: UUID, - target: SyncStatus, ctx: BaseContext, *, reason: str = "", ) -> SyncTransitionResult: - """Execute a validated, idempotent sync status transition. - - Side effects (schedule pause/unpause) run after the DB commit. - """ + """Pause a sync.""" ... 
- -class SyncServiceProtocol(Protocol): - """Sync execution: build orchestrator and run.""" - - async def run( + async def resume( self, - sync: schemas.Sync, - sync_job: schemas.SyncJob, - collection: schemas.CollectionRecord, - source_connection: schemas.Connection, + sync_id: UUID, ctx: BaseContext, - force_full_sync: bool = False, - execution_config: Optional[SyncConfig] = None, - access_token: Optional[str] = None, - ) -> schemas.Sync: - """Run a sync via SyncFactory + SyncOrchestrator.""" + *, + reason: str = "", + ) -> SyncTransitionResult: + """Resume a paused sync.""" ... - -class SyncLifecycleServiceProtocol(Protocol): - """Sync lifecycle: provision, run, get jobs, cancel, teardown.""" - - async def teardown_syncs_for_collection( + async def delete( self, db: AsyncSession, *, - sync_ids: List[UUID], + sync_id: UUID, collection_id: UUID, organization_id: UUID, ctx: ApiContext, cancel_timeout_seconds: int = 15, ) -> None: - """Cancel running workflows and schedule async cleanup for a collection's syncs.""" + """Cancel active workflows and schedule async cleanup.""" ... - async def provision_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - collection_id: UUID, - collection_readable_id: str, - source_entry: SourceRegistryEntry, - schedule_config: Optional[ScheduleConfig], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> Optional[SyncProvisionResult]: - """Create sync + job + Temporal schedule atomically. + # -- Jobs -- - Returns None for federated search sources (no sync needed). - """ + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + """Resolve destination connection IDs (interim — will move to a registry).""" ... 
- async def run( + async def trigger_run( self, db: AsyncSession, *, - id: UUID, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, ctx: ApiContext, force_full_sync: bool = False, - ) -> SourceConnectionJob: - """Trigger a sync run for a source connection.""" + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + """Create a PENDING job and start the Temporal workflow.""" ... async def get_jobs( self, db: AsyncSession, *, - id: UUID, + sync_id: UUID, ctx: ApiContext, limit: int = 100, - ) -> List[SourceConnectionJob]: - """Get sync jobs for a source connection.""" + ) -> List[schemas.SyncJob]: + """List jobs for a sync.""" ... async def cancel_job( self, db: AsyncSession, *, - source_connection_id: UUID, job_id: UUID, ctx: ApiContext, - ) -> SourceConnectionJob: - """Cancel a running sync job for a source connection.""" + ) -> schemas.SyncJob: + """Cancel a running sync job.""" + ... + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + """Validate and log force_full_sync intent.""" + ... + + # -- Execution -- + + async def run( + self, + sync: schemas.Sync, + sync_job: schemas.SyncJob, + collection: schemas.CollectionRecord, + source_connection: schemas.Connection, + ctx: BaseContext, + force_full_sync: bool = False, + execution_config: Optional[SyncConfig] = None, + access_token: Optional[str] = None, + ) -> schemas.Sync: + """Run a sync via SyncFactory + SyncOrchestrator.""" ... 
diff --git a/backend/airweave/domains/syncs/record_service.py b/backend/airweave/domains/syncs/record_service.py deleted file mode 100644 index eb223b233..000000000 --- a/backend/airweave/domains/syncs/record_service.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Sync record service: create and trigger operations for Sync/SyncJob records.""" - -from typing import List, Optional, Tuple -from uuid import UUID - -from fastapi import HTTPException -from sqlalchemy.ext.asyncio import AsyncSession - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID -from airweave.core.shared_models import SyncJobStatus, SyncStatus -from airweave.db.unit_of_work import UnitOfWork -from airweave.domains.connections.protocols import ConnectionRepositoryProtocol -from airweave.domains.syncs.jobs.protocols import SyncJobRepositoryProtocol -from airweave.domains.syncs.protocols import ( - SyncRecordServiceProtocol, - SyncRepositoryProtocol, -) -from airweave.schemas.sync import SyncCreate -from airweave.schemas.sync_job import SyncJobCreate - - -class SyncRecordService(SyncRecordServiceProtocol): - """Create syncs, trigger sync runs, and list sync jobs via injected repositories.""" - - def __init__( - self, - sync_repo: SyncRepositoryProtocol, - sync_job_repo: SyncJobRepositoryProtocol, - connection_repo: ConnectionRepositoryProtocol, - ) -> None: - """Initialize with injected repositories.""" - self._sync_repo = sync_repo - self._sync_job_repo = sync_job_repo - self._connection_repo = connection_repo - - async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: - """Resolve destination connection IDs.""" - return [NATIVE_VESPA_UUID] - - async def create_sync( - self, - db: AsyncSession, - *, - name: str, - source_connection_id: UUID, - destination_connection_ids: List[UUID], - cron_schedule: Optional[str], - run_immediately: bool, - ctx: ApiContext, - uow: UnitOfWork, - ) -> 
Tuple[schemas.Sync, Optional[schemas.SyncJob]]: - """Create a Sync record and optionally a PENDING SyncJob. - - All writes happen inside the caller's UoW (no commit). - """ - sync_in = SyncCreate( - name=name, - source_connection_id=source_connection_id, - destination_connection_ids=destination_connection_ids, - cron_schedule=cron_schedule, - status=SyncStatus.ACTIVE, - run_immediately=run_immediately, - ) - - sync_schema = await self._sync_repo.create( - uow.session, - obj_in=sync_in, - ctx=ctx, - uow=uow, - ) - await uow.session.flush() - - sync_job_schema: Optional[schemas.SyncJob] = None - if run_immediately: - sync_job = await self._sync_job_repo.create( - uow.session, - SyncJobCreate(sync_id=sync_schema.id, status=SyncJobStatus.PENDING), - ctx, - uow=uow, - ) - await uow.session.flush() - await uow.session.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - return sync_schema, sync_job_schema - - async def trigger_sync_run( - self, - db: AsyncSession, - sync_id: UUID, - ctx: ApiContext, - ) -> Tuple[schemas.Sync, schemas.SyncJob]: - """Trigger a manual sync run. - - Checks for existing active jobs, fetches the sync with - connections, creates a new SyncJob inside a UoW, and returns - both schemas. - - Raises: - HTTPException 400: if a job is already active. - ValueError: if the sync is not found. 
- """ - sync = await self._sync_repo.get(db, sync_id, ctx) - if not sync: - raise ValueError(f"Sync {sync_id} not found") - - if SyncStatus(sync.status) != SyncStatus.ACTIVE: - raise HTTPException( - status_code=409, - detail=f"Cannot trigger sync: sync is {sync.status}", - ) - - active_jobs = await self._sync_job_repo.get_active_for_sync(db, sync_id, ctx) - if active_jobs: - job_status = active_jobs[0].status.lower() - raise HTTPException( - status_code=400, - detail=f"Cannot start new sync: a sync job is already {job_status}", - ) - - sync_schema = schemas.Sync.model_validate(sync, from_attributes=True) - - async with UnitOfWork(db) as uow: - sync_job = await self._sync_job_repo.create( - uow.session, - schemas.SyncJobCreate( - sync_id=sync_id, - status=SyncJobStatus.PENDING, - ), - ctx, - uow=uow, - ) - await uow.commit() - await uow.session.refresh(sync_job) - sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) - - return sync_schema, sync_job_schema diff --git a/backend/airweave/domains/syncs/service.py b/backend/airweave/domains/syncs/service.py index 2685ad624..a1589c07a 100644 --- a/backend/airweave/domains/syncs/service.py +++ b/backend/airweave/domains/syncs/service.py @@ -1,40 +1,382 @@ -"""Sync execution service — runs a sync via SyncFactory + SyncOrchestrator. +"""Unified sync service — single transactional interface for the syncs domain. -Called exclusively from RunSyncActivity (Temporal worker). +Consolidates SyncRecordService, SyncLifecycleService, and the sync runner +into one service with clean, directed semantics. All callers interact through +this interface; internal implementation details (state machines, repos) are hidden. + +Methods speak the sync domain language: create, get, pause, resume, delete, +trigger_run, cancel_job, get_jobs, run. No source_connection types cross +this boundary. 
""" -from typing import Optional +import asyncio +import re +from datetime import datetime, timezone +from typing import List, Optional, Tuple +from uuid import UUID + +from fastapi import HTTPException +from sqlalchemy.ext.asyncio import AsyncSession from airweave import schemas from airweave.api.context import ApiContext +from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID +from airweave.core.context import BaseContext from airweave.core.shared_models import SyncJobStatus, SyncStatus from airweave.db.session import get_db_context +from airweave.db.unit_of_work import UnitOfWork from airweave.domains.sources.exceptions.classifier import classify_error +from airweave.domains.sources.types import SourceRegistryEntry from airweave.domains.sync_pipeline.config import SyncConfig from airweave.domains.sync_pipeline.protocols import SyncFactoryProtocol -from airweave.domains.syncs.jobs.protocols import SyncJobStateMachineProtocol +from airweave.domains.syncs.jobs.protocols import ( + SyncJobRepositoryProtocol, + SyncJobStateMachineProtocol, +) from airweave.domains.syncs.protocols import ( + SyncCursorRepositoryProtocol, + SyncRepositoryProtocol, SyncServiceProtocol, SyncStateMachineProtocol, ) +from airweave.domains.syncs.types import ( + CONTINUOUS_SOURCE_DEFAULT_CRON, + DAILY_CRON_TEMPLATE, + SyncProvisionResult, + SyncTransitionResult, +) +from airweave.domains.temporal.protocols import ( + TemporalScheduleServiceProtocol, + TemporalWorkflowServiceProtocol, +) +from airweave.schemas.source_connection import ScheduleConfig +from airweave.schemas.sync import SyncCreate +from airweave.schemas.sync_job import SyncJobCreate + +_SUB_HOURLY_PATTERN = re.compile(r"^\*/([1-5]?[0-9]) \* \* \* \*$") class SyncService(SyncServiceProtocol): - """Runs a sync via SyncFactory + SyncOrchestrator. + """Unified sync service — the public interface for the syncs domain. - Stateless — the only production caller is RunSyncActivity. 
+ Callers use directed methods (create, pause, resume, delete) rather than + raw state transitions. The state machine is an internal implementation detail. """ - def __init__( + def __init__( # noqa: D107 self, - state_machine: SyncJobStateMachineProtocol, + sync_repo: SyncRepositoryProtocol, + sync_job_repo: SyncJobRepositoryProtocol, + sync_cursor_repo: SyncCursorRepositoryProtocol, + state_machine: SyncStateMachineProtocol, + job_state_machine: SyncJobStateMachineProtocol, + temporal_workflow_service: TemporalWorkflowServiceProtocol, + temporal_schedule_service: TemporalScheduleServiceProtocol, sync_factory: SyncFactoryProtocol, - sync_state_machine: SyncStateMachineProtocol, ) -> None: - """Initialize with state machine and factory dependencies.""" + self._sync_repo = sync_repo + self._sync_job_repo = sync_job_repo + self._sync_cursor_repo = sync_cursor_repo self._state_machine = state_machine + self._job_state_machine = job_state_machine + self._temporal_workflow_service = temporal_workflow_service + self._temporal_schedule_service = temporal_schedule_service self._sync_factory = sync_factory - self._sync_state_machine = sync_state_machine + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def create( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + collection_id: UUID, + collection_readable_id: str, + source_entry: SourceRegistryEntry, + schedule_config: Optional[ScheduleConfig], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> SyncProvisionResult: + """Create sync + optional job + Temporal schedule atomically. + + Raises ValueError if called for federated sources or when there is + neither a schedule nor an immediate run request — callers must guard + these cases before calling create. 
+ """ + if source_entry.federated_search: + raise ValueError(f"Cannot create sync for federated source '{source_entry.short_name}'") + + cron = self._resolve_cron(schedule_config, source_entry, ctx) + + if not cron and not run_immediately: + raise ValueError("Cannot create sync: no schedule and run_immediately=False") + + if cron: + self._validate_cron_for_source(cron, source_entry) + + sync_schema, sync_job_schema = await self._create_sync_records( + uow.session, + name=f"Sync for {name}", + source_connection_id=source_connection_id, + destination_connection_ids=destination_connection_ids, + cron_schedule=cron, + run_immediately=run_immediately, + ctx=ctx, + uow=uow, + ) + + if cron: + await self._temporal_schedule_service.create_or_update_schedule( + sync_id=sync_schema.id, + cron_schedule=cron, + db=uow.session, + ctx=ctx, + uow=uow, + collection_readable_id=collection_readable_id, + connection_id=source_connection_id, + ) + + return SyncProvisionResult( + sync_id=sync_schema.id, + sync=sync_schema, + sync_job=sync_job_schema, + cron_schedule=cron, + ) + + async def get(self, db: AsyncSession, *, sync_id: UUID, ctx: BaseContext) -> schemas.Sync: + """Get a sync by ID.""" + sync = await self._sync_repo.get(db, sync_id, ctx) + if not sync: + raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found") + return sync + + async def pause( + self, + sync_id: UUID, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Pause a sync: update DB status, pause Temporal schedules.""" + return await self._state_machine.transition( + sync_id=sync_id, target=SyncStatus.PAUSED, ctx=ctx, reason=reason + ) + + async def resume( + self, + sync_id: UUID, + ctx: BaseContext, + *, + reason: str = "", + ) -> SyncTransitionResult: + """Resume a paused sync: update DB status, unpause Temporal schedules.""" + return await self._state_machine.transition( + sync_id=sync_id, target=SyncStatus.ACTIVE, ctx=ctx, reason=reason + ) + + async def delete( + 
self, + db: AsyncSession, + *, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + cancel_timeout_seconds: int = 15, + ) -> None: + """Cancel active workflows and schedule async cleanup for a single sync. + + 1. Cancels PENDING/RUNNING workflows via Temporal. + 2. Polls until terminal state (up to cancel_timeout_seconds). + 3. Schedules async cleanup workflow for Vespa/ARF/schedules. + + The caller is responsible for the CASCADE delete of DB records. + """ + needs_wait = await self._cancel_active_sync(db, sync_id, ctx) + if needs_wait: + await self._wait_for_terminal(db, sync_id, cancel_timeout_seconds, ctx) + await self._schedule_cleanup(sync_id, collection_id, organization_id, ctx) + + # ------------------------------------------------------------------ + # Jobs + # ------------------------------------------------------------------ + + async def resolve_destination_ids(self, db: AsyncSession, ctx: ApiContext) -> List[UUID]: + """Resolve destination connection IDs.""" + return [NATIVE_VESPA_UUID] + + async def trigger_run( + self, + db: AsyncSession, + *, + sync_id: UUID, + collection: schemas.CollectionRecord, + connection: schemas.Connection, + ctx: ApiContext, + force_full_sync: bool = False, + ) -> Tuple[schemas.Sync, schemas.SyncJob]: + """Create a PENDING job and start the Temporal workflow. + + Validates the sync is ACTIVE and no active jobs exist, creates the + job record, then starts the Temporal workflow. 
+ """ + sync = await self._sync_repo.get(db, sync_id, ctx) + if not sync: + raise HTTPException(status_code=404, detail=f"Sync {sync_id} not found") + + if SyncStatus(sync.status) != SyncStatus.ACTIVE: + raise HTTPException( + status_code=409, + detail=f"Cannot trigger sync: sync is {sync.status}", + ) + + active_jobs = await self._sync_job_repo.get_active_for_sync(db, sync_id, ctx) + if active_jobs: + job_status = active_jobs[0].status.lower() + raise HTTPException( + status_code=400, + detail=f"Cannot start new sync: a sync job is already {job_status}", + ) + + sync_schema = schemas.Sync.model_validate(sync, from_attributes=True) + + sync_job = await self._sync_job_repo.create( + db, + SyncJobCreate(sync_id=sync_id, status=SyncJobStatus.PENDING), + ctx, + ) + await db.flush() + await db.refresh(sync_job) + sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + await self._temporal_workflow_service.run_source_connection_workflow( + sync=sync_schema, + sync_job=sync_job_schema, + collection=collection, + connection=connection, + ctx=ctx, + force_full_sync=force_full_sync, + ) + + return sync_schema, sync_job_schema + + async def get_jobs( + self, + db: AsyncSession, + *, + sync_id: UUID, + ctx: ApiContext, + limit: int = 100, + ) -> List[schemas.SyncJob]: + """List jobs for a sync, most recent first.""" + jobs = await self._sync_job_repo.get_all_by_sync_id(db, sync_id, ctx, limit=limit) + return [schemas.SyncJob.model_validate(j, from_attributes=True) for j in jobs] + + async def cancel_job( + self, + db: AsyncSession, + *, + job_id: UUID, + ctx: ApiContext, + ) -> schemas.SyncJob: + """Cancel a pending or running sync job. + + PENDING jobs are transitioned directly to CANCELLED (the CANCELLING + intermediate state is only valid from RUNNING). RUNNING jobs go + through CANCELLING and rely on the Temporal workflow for the final + CANCELLED transition. 
+ """ + sync_job = await self._sync_job_repo.get(db, job_id, ctx) + if not sync_job: + raise HTTPException(status_code=404, detail="Sync job not found") + + if sync_job.status not in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): + raise HTTPException( + status_code=400, + detail=f"Cannot cancel job in {sync_job.status} state", + ) + + if sync_job.status == SyncJobStatus.PENDING: + # PENDING → CANCELLED directly (PENDING → CANCELLING is invalid) + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLED, ctx=ctx + ) + # Best-effort workflow cleanup — workflow may not exist yet for + # PENDING jobs, but if it does we need to stop it. + cancel_result = await self._cancel_temporal_workflow_with_retry(job_id, ctx) + if not cancel_result["success"] and cancel_result["workflow_found"]: + ctx.logger.warning( + f"Temporal cancel failed for PENDING job {job_id}, " + "workflow may continue running against cancelled job", + ) + await db.refresh(sync_job) + return schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + # RUNNING → CANCELLING; workflow handles CANCELLING → CANCELLED + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLING, ctx=ctx + ) + + cancel_result = await self._cancel_temporal_workflow_with_retry(job_id, ctx) + + if not cancel_result["success"]: + raise HTTPException( + status_code=502, detail="Failed to request cancellation from Temporal" + ) + + if not cancel_result["workflow_found"]: + # NOT_FOUND means the workflow already completed, never started, + # or was cleaned up. Check current DB state before marking cancelled. 
+ await db.refresh(sync_job) + terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} + if sync_job.status not in terminal: + ctx.logger.info(f"Workflow not found for job {job_id}, marking CANCELLED") + await self._job_state_machine.transition( + sync_job_id=job_id, target=SyncJobStatus.CANCELLED, ctx=ctx + ) + + await db.refresh(sync_job) + return schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + async def _cancel_temporal_workflow_with_retry( + self, job_id: UUID, ctx: ApiContext, max_retries: int = 1 + ) -> dict[str, bool]: + """Send cancellation to Temporal with a single retry on RPC failure.""" + for attempt in range(1 + max_retries): + result = await self._temporal_workflow_service.cancel_sync_job_workflow( + str(job_id), ctx + ) + if result["success"] or result["workflow_found"]: + return result + if attempt < max_retries: + await asyncio.sleep(0.5) + ctx.logger.info(f"Retrying cancel for job {job_id} (attempt {attempt + 2})") + return result + + async def validate_force_full_sync( + self, db: AsyncSession, sync_id: UUID, ctx: ApiContext + ) -> None: + """Log force_full_sync intent. No-op if no cursor (already a full sync).""" + cursor = await self._sync_cursor_repo.get_by_sync_id(db, sync_id, ctx) + if not cursor or not cursor.cursor_data: + ctx.logger.info( + f"force_full_sync requested but no cursor data exists for sync {sync_id}. " + "This sync will perform a full sync by default." + ) + return + ctx.logger.info( + f"Force full sync requested for continuous sync {sync_id}. " + "Will ignore cursor data and perform full sync with orphaned entity cleanup." 
+ ) + + # ------------------------------------------------------------------ + # Execution (Temporal activity entry point) + # ------------------------------------------------------------------ async def run( self, @@ -47,7 +389,10 @@ async def run( execution_config: Optional[SyncConfig] = None, access_token: Optional[str] = None, ) -> schemas.Sync: - """Run a sync.""" + """Run a sync via SyncFactory + SyncOrchestrator. + + Called exclusively from RunSyncActivity (Temporal worker). + """ try: async with get_db_context() as db: orchestrator = await self._sync_factory.create_orchestrator( @@ -66,7 +411,7 @@ async def run( classification = classify_error(e) - await self._state_machine.transition( + await self._job_state_machine.transition( sync_job_id=sync_job.id, target=SyncJobStatus.FAILED, ctx=ctx, @@ -76,7 +421,7 @@ async def run( if classification.category is not None and sync: try: - await self._sync_state_machine.transition( + await self._state_machine.transition( sync_id=sync.id, target=SyncStatus.PAUSED, ctx=ctx, @@ -88,3 +433,167 @@ async def run( raise e return await orchestrator.run() + + # ------------------------------------------------------------------ + # Private: record creation + # ------------------------------------------------------------------ + + async def _create_sync_records( + self, + db: AsyncSession, + *, + name: str, + source_connection_id: UUID, + destination_connection_ids: List[UUID], + cron_schedule: Optional[str], + run_immediately: bool, + ctx: ApiContext, + uow: UnitOfWork, + ) -> Tuple[schemas.Sync, Optional[schemas.SyncJob]]: + """Create a Sync record and optionally a PENDING SyncJob. + + All writes happen inside the caller's UoW (no commit). 
+ """ + sync_in = SyncCreate( + name=name, + source_connection_id=source_connection_id, + destination_connection_ids=destination_connection_ids, + cron_schedule=cron_schedule, + status=SyncStatus.ACTIVE, + run_immediately=run_immediately, + ) + + sync_schema = await self._sync_repo.create(uow.session, obj_in=sync_in, ctx=ctx, uow=uow) + await uow.session.flush() + + sync_job_schema: Optional[schemas.SyncJob] = None + if run_immediately: + sync_job = await self._sync_job_repo.create( + uow.session, + SyncJobCreate(sync_id=sync_schema.id, status=SyncJobStatus.PENDING), + ctx, + uow=uow, + ) + await uow.session.flush() + await uow.session.refresh(sync_job) + sync_job_schema = schemas.SyncJob.model_validate(sync_job, from_attributes=True) + + return sync_schema, sync_job_schema + + # ------------------------------------------------------------------ + # Private: cron resolution + # ------------------------------------------------------------------ + + def _resolve_cron( + self, + schedule_config: Optional[ScheduleConfig], + source_entry: SourceRegistryEntry, + ctx: ApiContext, + ) -> Optional[str]: + """Resolve cron schedule from config or source defaults.""" + if schedule_config is not None: + if schedule_config.cron is not None: + return schedule_config.cron + ctx.logger.info("Schedule cron explicitly null, no schedule") + return None + + if source_entry.supports_continuous: + ctx.logger.info("Continuous source, defaulting to 5-minute schedule") + return CONTINUOUS_SOURCE_DEFAULT_CRON + + now_utc = datetime.now(timezone.utc) + cron = DAILY_CRON_TEMPLATE.format(minute=now_utc.minute, hour=now_utc.hour) + ctx.logger.info(f"Defaulting to daily at {now_utc.hour:02d}:{now_utc.minute:02d} UTC") + return cron + + def _validate_cron_for_source(self, cron: str, source_entry: SourceRegistryEntry) -> None: + """Reject sub-hourly schedules for non-continuous sources.""" + if source_entry.supports_continuous: + return + + if cron == "* * * * *": + raise HTTPException( + 
status_code=400, + detail=( + f"Source '{source_entry.short_name}' does not support " + f"continuous syncs. Minimum interval is 1 hour." + ), + ) + + match = _SUB_HOURLY_PATTERN.match(cron) + if match and int(match.group(1)) < 60: + raise HTTPException( + status_code=400, + detail=( + f"Source '{source_entry.short_name}' does not support " + f"continuous syncs. Minimum interval is 1 hour." + ), + ) + + # ------------------------------------------------------------------ + # Private: delete helpers + # ------------------------------------------------------------------ + + async def _cancel_active_sync( + self, + db: AsyncSession, + sync_id: UUID, + ctx: ApiContext, + ) -> bool: + """Cancel PENDING/RUNNING job for a sync. Returns True if it needs waiting.""" + non_terminal = {SyncJobStatus.PENDING, SyncJobStatus.RUNNING, SyncJobStatus.CANCELLING} + latest_job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) + if not latest_job or latest_job.status not in non_terminal: + return False + if latest_job.status in (SyncJobStatus.PENDING, SyncJobStatus.RUNNING): + try: + await self._temporal_workflow_service.cancel_sync_job_workflow( + str(latest_job.id), ctx + ) + ctx.logger.info(f"Cancelled job {latest_job.id} before deletion") + except Exception as e: + ctx.logger.warning(f"Failed to cancel job {latest_job.id}: {e}") + return True + + async def _wait_for_terminal( + self, + db: AsyncSession, + sync_id: UUID, + timeout_seconds: int, + ctx: ApiContext, + ) -> None: + """Poll until the sync's latest job reaches a terminal state or timeout.""" + terminal = {SyncJobStatus.COMPLETED, SyncJobStatus.FAILED, SyncJobStatus.CANCELLED} + elapsed = 0.0 + while elapsed < timeout_seconds: + await asyncio.sleep(1.0) + elapsed += 1.0 + db.expire_all() + job = await self._sync_job_repo.get_latest_by_sync_id(db, sync_id=sync_id) + if not job or job.status in terminal: + return + ctx.logger.warning( + f"Sync {sync_id} did not reach terminal state " + f"within 
{timeout_seconds}s -- proceeding with deletion anyway" + ) + + async def _schedule_cleanup( + self, + sync_id: UUID, + collection_id: UUID, + organization_id: UUID, + ctx: ApiContext, + ) -> None: + """Schedule a Temporal workflow for async Vespa/ARF cleanup.""" + try: + await self._temporal_workflow_service.start_cleanup_sync_data_workflow( + sync_ids=[str(sync_id)], + collection_id=str(collection_id), + organization_id=str(organization_id), + ctx=ctx, + ) + except Exception as e: + ctx.logger.error( + f"Failed to schedule async cleanup for sync {sync_id}: {e}. " + f"Data may be orphaned in Vespa/ARF." + ) diff --git a/backend/airweave/domains/syncs/tests/test_lifecycle_service.py b/backend/airweave/domains/syncs/tests/test_lifecycle_service.py deleted file mode 100644 index 7a926292c..000000000 --- a/backend/airweave/domains/syncs/tests/test_lifecycle_service.py +++ /dev/null @@ -1,876 +0,0 @@ -"""Table-driven unit tests for SyncLifecycleService. - -Covers provision_sync(), run(), get_jobs(), and cancel_job() with -happy paths and error edge cases. 
-""" - -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Optional -from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID, uuid4 - -import pytest -from fastapi import HTTPException - -from airweave import schemas -from airweave.api.context import ApiContext -from airweave.core.logging import logger -from airweave.core.shared_models import AuthMethod, SyncJobStatus, SyncStatus -from airweave.domains.collections.fakes.repository import FakeCollectionRepository -from airweave.domains.connections.fakes.repository import FakeConnectionRepository -from airweave.domains.source_connections.fakes.repository import ( - FakeSourceConnectionRepository, -) -from airweave.domains.source_connections.fakes.response import FakeResponseBuilder -from airweave.domains.sources.types import SourceRegistryEntry -from airweave.domains.syncs.fakes.cursor_repository import FakeSyncCursorRepository -from airweave.domains.syncs.jobs.fakes.repository import FakeSyncJobRepository -from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService -from airweave.domains.syncs.lifecycle_service import SyncLifecycleService -from airweave.domains.syncs.types import CONTINUOUS_SOURCE_DEFAULT_CRON, SyncProvisionResult -from airweave.domains.temporal.fakes.schedule_service import FakeTemporalScheduleService -from airweave.domains.temporal.fakes.service import FakeTemporalWorkflowService -from airweave.models.collection import Collection # spec only -from airweave.models.connection import Connection # spec only -from airweave.models.source_connection import SourceConnection # spec only -from airweave.models.sync_cursor import SyncCursor # spec only -from airweave.models.sync_job import SyncJob # spec only -from airweave.platform.configs._base import Fields -from airweave.schemas.organization import Organization -from airweave.schemas.source_connection import ScheduleConfig - -NOW = datetime.now(timezone.utc) -ORG_ID = 
uuid4() -SC_ID = uuid4() -SYNC_ID = uuid4() -JOB_ID = uuid4() -COLLECTION_ID = uuid4() -CONNECTION_ID = uuid4() -DEST_CONN_ID = uuid4() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _ctx() -> ApiContext: - org = Organization(id=str(ORG_ID), name="Test Org", created_at=NOW, modified_at=NOW) - return ApiContext( - request_id="test-req", - organization=org, - auth_method=AuthMethod.SYSTEM, - logger=logger.with_context(request_id="test-req"), - ) - - -def _source_connection( - id: UUID = SC_ID, - sync_id: Optional[UUID] = SYNC_ID, - connection_id: UUID = CONNECTION_ID, -) -> MagicMock: - sc = MagicMock(spec=SourceConnection) - sc.id = id - sc.sync_id = sync_id - sc.connection_id = connection_id - sc.readable_collection_id = "test-collection" - sc.organization_id = ORG_ID - return sc - - -def _collection() -> MagicMock: - col = MagicMock(spec=Collection) - col.id = COLLECTION_ID - col.name = "Test Collection" - col.readable_id = "test-collection" - col.organization_id = ORG_ID - col.vector_db_deployment_metadata_id = uuid4() - col.sync_config = None - col.created_by_email = None - col.modified_by_email = None - return col - - -def _connection() -> MagicMock: - conn = MagicMock(spec=Connection) - conn.id = CONNECTION_ID - conn.name = "Test Connection" - conn.readable_id = "test-connection-abc123" - conn.description = None - conn.short_name = "github" - conn.integration_type = "source" - conn.integration_credential_id = None - conn.status = "active" - conn.organization_id = ORG_ID - conn.created_at = NOW - conn.modified_at = NOW - conn.created_by_email = None - conn.modified_by_email = None - return conn - - -def _sync_job( - id: UUID = JOB_ID, - sync_id: UUID = SYNC_ID, - status: SyncJobStatus = SyncJobStatus.PENDING, -) -> MagicMock: - job = MagicMock(spec=SyncJob) - job.id = id - job.sync_id = sync_id - job.status = status - 
job.organization_id = ORG_ID - job.created_at = NOW - job.modified_at = NOW - job.created_by_email = "test@example.com" - job.modified_by_email = "test@example.com" - return job - - -class _FakeEventBus: - """Minimal fake for EventBus.publish.""" - - def __init__(self) -> None: - self.events: list = [] - - async def publish(self, event) -> None: - self.events.append(event) - - -def _build_service( - sc_repo=None, - collection_repo=None, - connection_repo=None, - sync_cursor_repo=None, - sync_service=None, - state_machine=None, - sync_job_repo=None, - temporal_workflow_service=None, - temporal_schedule_service=None, - response_builder=None, - event_bus=None, -) -> SyncLifecycleService: - return SyncLifecycleService( - sc_repo=sc_repo or FakeSourceConnectionRepository(), - collection_repo=collection_repo or FakeCollectionRepository(), - connection_repo=connection_repo or FakeConnectionRepository(), - sync_cursor_repo=sync_cursor_repo or FakeSyncCursorRepository(), - sync_service=sync_service or FakeSyncRecordService(), - state_machine=state_machine or AsyncMock(), - sync_job_repo=sync_job_repo or FakeSyncJobRepository(), - temporal_workflow_service=temporal_workflow_service or FakeTemporalWorkflowService(), - temporal_schedule_service=temporal_schedule_service or FakeTemporalScheduleService(), - response_builder=response_builder or FakeResponseBuilder(), - event_bus=event_bus or _FakeEventBus(), - ) - - -# --------------------------------------------------------------------------- -# Source entry helper -# --------------------------------------------------------------------------- - - -def _source_entry( - short_name: str = "github", - supports_continuous: bool = False, - federated_search: bool = False, -) -> SourceRegistryEntry: - """Create a minimal SourceRegistryEntry for testing.""" - empty_fields = Fields(fields=[]) - return SourceRegistryEntry( - name="Test Source", - short_name=short_name, - description="Test source for unit tests", - class_name="FakeSource", 
- source_class_ref=type("FakeSource", (), {}), - config_ref=None, - auth_config_ref=None, - auth_fields=empty_fields, - config_fields=empty_fields, - supported_auth_providers=[], - runtime_auth_all_fields=[], - runtime_auth_optional_fields=set(), - auth_methods=None, - oauth_type=None, - requires_byoc=False, - supports_continuous=supports_continuous, - supports_cursor=False, - federated_search=federated_search, - supports_temporal_relevance=False, - supports_access_control=False, - rate_limit_level=None, - feature_flag=None, - labels=None, - output_entity_definitions=[], - ) - - -def _sync_schema(id: UUID = SYNC_ID) -> schemas.Sync: - """Create a minimal Sync schema for testing.""" - return schemas.Sync( - id=id, - name="Sync for Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - status=SyncStatus.ACTIVE, - organization_id=ORG_ID, - created_at=NOW, - modified_at=NOW, - ) - - -def _sync_job_schema(id: UUID = JOB_ID, sync_id: UUID = SYNC_ID) -> schemas.SyncJob: - """Create a minimal SyncJob schema for testing.""" - return schemas.SyncJob( - id=id, - sync_id=sync_id, - status=SyncJobStatus.PENDING, - organization_id=ORG_ID, - created_at=NOW, - modified_at=NOW, - ) - - -class _FakeUoW: - """Minimal UoW fake — just exposes .session.""" - - def __init__(self, session=None): - self.session = session or AsyncMock() - - -# --------------------------------------------------------------------------- -# run() tests -# --------------------------------------------------------------------------- - - -@dataclass -class RunCase: - """Table-driven case for run().""" - - name: str - sc: Optional[SourceConnection] = None - collection: Optional[Collection] = None - connection: Optional[Connection] = None - force_full_sync: bool = False - cursor: Optional[SyncCursor] = None - trigger_result: Optional[tuple] = None - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -_SYNC_SCHEMA = MagicMock() 
-_SYNC_SCHEMA.model_dump.return_value = {} -_SYNC_JOB_SCHEMA = MagicMock() -_SYNC_JOB_SCHEMA.id = JOB_ID -_SYNC_JOB_SCHEMA.to_source_connection_job.return_value = MagicMock() - - -RUN_CASES = [ - RunCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - RunCase( - name="no_sync_id", - sc=_source_connection(sync_id=None), - expected_error="Source connection has no associated sync", - expected_status=400, - ), - # force_full_sync_no_cursor: removed — service now logs info and proceeds - # (no cursor means first sync, which is inherently a full sync) -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", RUN_CASES, ids=lambda c: c.name) -async def test_run_errors(case: RunCase): - """Test run() error paths.""" - sc_repo = FakeSourceConnectionRepository() - collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_cursor_repo = FakeSyncCursorRepository() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.collection: - collection_repo.seed(case.collection.id, case.collection) - collection_repo.seed_readable(case.collection.readable_id, case.collection) - if case.connection: - connection_repo.seed(case.connection.id, case.connection) - if case.cursor: - sync_cursor_repo.seed(SYNC_ID, case.cursor) - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_cursor_repo=sync_cursor_repo, - ) - - with pytest.raises(HTTPException) as exc_info: - await svc.run( - AsyncMock(), - id=case.sc.id if case.sc else uuid4(), - ctx=_ctx(), - force_full_sync=case.force_full_sync, - ) - - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_run_force_full_sync_happy_path(): - """Test run() with force_full_sync=True and valid cursor data.""" - sc_repo = FakeSourceConnectionRepository() - 
collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_cursor_repo = FakeSyncCursorRepository() - sync_service = FakeSyncRecordService() - temporal_workflow_service = FakeTemporalWorkflowService() - event_bus = _FakeEventBus() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - collection_repo.seed_readable("test-collection", _collection()) - connection_repo.seed(CONNECTION_ID, _connection()) - - cursor = MagicMock(spec=SyncCursor) - cursor.cursor_data = {"last_modified": "2024-01-01"} - sync_cursor_repo.seed(SYNC_ID, cursor) - - mock_sync = MagicMock() - mock_sync_job = MagicMock() - sync_service.set_trigger_result(mock_sync, mock_sync_job) - - mock_collection_schema = MagicMock() - mock_collection_schema.id = COLLECTION_ID - mock_collection_schema.name = "Test Collection" - mock_collection_schema.readable_id = "test-collection" - - mock_connection_schema = MagicMock() - mock_connection_schema.short_name = "github" - - expected_sc_job = MagicMock() - mock_sj_schema = MagicMock() - mock_sj_schema.id = JOB_ID - mock_sj_schema.to_source_connection_job.return_value = expected_sc_job - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_cursor_repo=sync_cursor_repo, - sync_service=sync_service, - temporal_workflow_service=temporal_workflow_service, - event_bus=event_bus, - ) - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with ( - patch(f"{_mod}.Collection.model_validate", return_value=mock_collection_schema), - patch(f"{_mod}.Connection.model_validate", return_value=mock_connection_schema), - patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema), - ): - result = await svc.run(AsyncMock(), id=SC_ID, ctx=_ctx(), force_full_sync=True) - - assert result == expected_sc_job - assert len(event_bus.events) == 1 - assert len(temporal_workflow_service._calls) == 1 - wf_call = temporal_workflow_service._calls[0] - assert wf_call[0] 
== "run_source_connection_workflow" - assert wf_call[7] is True # force_full_sync arg - - -@pytest.mark.asyncio -async def test_run_happy_path(): - """Test run() happy path: triggers workflow and publishes event.""" - sc_repo = FakeSourceConnectionRepository() - collection_repo = FakeCollectionRepository() - connection_repo = FakeConnectionRepository() - sync_service = FakeSyncRecordService() - temporal_workflow_service = FakeTemporalWorkflowService() - event_bus = _FakeEventBus() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - collection_repo.seed_readable("test-collection", _collection()) - connection_repo.seed(CONNECTION_ID, _connection()) - - mock_sync = MagicMock() - mock_sync_job = MagicMock() - sync_service.set_trigger_result(mock_sync, mock_sync_job) - - mock_collection_schema = MagicMock() - mock_collection_schema.id = COLLECTION_ID - mock_collection_schema.name = "Test Collection" - mock_collection_schema.readable_id = "test-collection" - - mock_connection_schema = MagicMock() - mock_connection_schema.short_name = "github" - - expected_sc_job = MagicMock() - mock_sj_schema = MagicMock() - mock_sj_schema.id = JOB_ID - mock_sj_schema.to_source_connection_job.return_value = expected_sc_job - - svc = _build_service( - sc_repo=sc_repo, - collection_repo=collection_repo, - connection_repo=connection_repo, - sync_service=sync_service, - temporal_workflow_service=temporal_workflow_service, - event_bus=event_bus, - ) - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with ( - patch(f"{_mod}.Collection.model_validate", return_value=mock_collection_schema), - patch(f"{_mod}.Connection.model_validate", return_value=mock_connection_schema), - patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema), - ): - result = await svc.run(AsyncMock(), id=SC_ID, ctx=_ctx()) - - assert result == expected_sc_job - assert len(event_bus.events) == 1 - assert len(temporal_workflow_service._calls) == 1 - assert temporal_workflow_service._calls[0][0] 
== "run_source_connection_workflow" - - -# --------------------------------------------------------------------------- -# get_jobs() tests -# --------------------------------------------------------------------------- - - -@dataclass -class GetJobsCase: - """Table-driven case for get_jobs().""" - - name: str - sc: Optional[SourceConnection] = None - jobs: list = field(default_factory=list) - expected_count: int = 0 - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -GET_JOBS_CASES = [ - GetJobsCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - GetJobsCase( - name="no_sync_id_returns_empty", - sc=_source_connection(sync_id=None), - expected_count=0, - ), - GetJobsCase( - name="empty_jobs", - sc=_source_connection(), - expected_count=0, - ), - GetJobsCase( - name="with_seeded_jobs", - sc=_source_connection(), - jobs=[_sync_job(id=uuid4()), _sync_job(id=uuid4())], - expected_count=2, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", GET_JOBS_CASES, ids=lambda c: c.name) -async def test_get_jobs(case: GetJobsCase): - """Test get_jobs() with table-driven cases.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.jobs and case.sc and case.sc.sync_id: - sync_job_repo.seed_jobs_for_sync(case.sc.sync_id, case.jobs) - - svc = _build_service(sc_repo=sc_repo, sync_job_repo=sync_job_repo) - - if case.expected_error: - with pytest.raises(HTTPException) as exc_info: - await svc.get_jobs(AsyncMock(), id=case.sc.id if case.sc else uuid4(), ctx=_ctx()) - assert exc_info.value.status_code == case.expected_status - else: - result = await svc.get_jobs(AsyncMock(), id=case.sc.id, ctx=_ctx()) - assert len(result) == case.expected_count - - -# --------------------------------------------------------------------------- -# cancel_job() tests -# 
--------------------------------------------------------------------------- - - -@dataclass -class CancelCase: - """Table-driven case for cancel_job().""" - - name: str - sc: Optional[SourceConnection] = None - job: Optional[SyncJob] = None - cancel_success: bool = True - workflow_found: bool = True - expected_error: Optional[str] = None - expected_status: Optional[int] = None - - -CANCEL_CASES = [ - CancelCase( - name="missing_source_connection", - expected_error="Source connection not found", - expected_status=404, - ), - CancelCase( - name="no_sync_id", - sc=_source_connection(sync_id=None), - expected_error="Source connection has no associated sync", - expected_status=400, - ), - CancelCase( - name="job_not_found", - sc=_source_connection(), - expected_error="Sync job not found", - expected_status=404, - ), - CancelCase( - name="wrong_sync", - sc=_source_connection(), - job=_sync_job(sync_id=uuid4()), - expected_error="Sync job does not belong to this source connection", - expected_status=400, - ), - CancelCase( - name="non_cancellable_state", - sc=_source_connection(), - job=_sync_job(status=SyncJobStatus.COMPLETED), - expected_error="Cannot cancel job in", - expected_status=400, - ), - CancelCase( - name="temporal_failure", - sc=_source_connection(), - job=_sync_job(status=SyncJobStatus.RUNNING), - cancel_success=False, - expected_error="Failed to request cancellation from Temporal", - expected_status=502, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", CANCEL_CASES, ids=lambda c: c.name) -async def test_cancel_job_errors(case: CancelCase): - """Test cancel_job() error paths.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - if case.sc: - sc_repo.seed(case.sc.id, case.sc) - if case.job: - sync_job_repo.seed(case.job.id, case.job) - - temporal_workflow_service.set_cancel_result( - {"success": 
case.cancel_success, "workflow_found": case.workflow_found} - ) - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - with pytest.raises(HTTPException) as exc_info: - await svc.cancel_job( - AsyncMock(), - source_connection_id=case.sc.id if case.sc else uuid4(), - job_id=case.job.id if case.job else JOB_ID, - ctx=_ctx(), - ) - - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - - -@pytest.mark.asyncio -async def test_cancel_job_happy_path(): - """Successful cancel: workflow found, job transitions to CANCELLING.""" - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - - job = _sync_job(status=SyncJobStatus.RUNNING) - sync_job_repo.seed(JOB_ID, job) - - temporal_workflow_service.set_cancel_result({"success": True, "workflow_found": True}) - - db_mock = AsyncMock() - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - mock_sj_schema = MagicMock() - expected_result = MagicMock() - mock_sj_schema.to_source_connection_job.return_value = expected_result - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema): - result = await svc.cancel_job( - db_mock, source_connection_id=SC_ID, job_id=JOB_ID, ctx=_ctx() - ) - - assert result == expected_result - assert state_machine.transition.await_count == 1 - call_kwargs = state_machine.transition.call_args.kwargs - assert call_kwargs["target"] == SyncJobStatus.CANCELLING - assert len(temporal_workflow_service._calls) == 1 - - -@pytest.mark.asyncio -async def 
test_cancel_job_workflow_not_found(): - """When workflow is not found, job should be marked CANCELLED directly.""" - from unittest.mock import patch - - sc_repo = FakeSourceConnectionRepository() - sync_job_repo = FakeSyncJobRepository() - state_machine = AsyncMock() - temporal_workflow_service = FakeTemporalWorkflowService() - - sc = _source_connection() - sc_repo.seed(SC_ID, sc) - - job = _sync_job(status=SyncJobStatus.RUNNING) - sync_job_repo.seed(JOB_ID, job) - - temporal_workflow_service.set_cancel_result({"success": True, "workflow_found": False}) - - db_mock = AsyncMock() - - svc = _build_service( - sc_repo=sc_repo, - sync_job_repo=sync_job_repo, - state_machine=state_machine, - temporal_workflow_service=temporal_workflow_service, - ) - - mock_sj_schema = MagicMock() - mock_sj_schema.to_source_connection_job.return_value = MagicMock() - - _mod = "airweave.domains.syncs.lifecycle_service.schemas" - with patch(f"{_mod}.SyncJob.model_validate", return_value=mock_sj_schema): - await svc.cancel_job(db_mock, source_connection_id=SC_ID, job_id=JOB_ID, ctx=_ctx()) - - assert state_machine.transition.await_count == 2 - targets = [c.kwargs["target"] for c in state_machine.transition.call_args_list] - assert targets == [SyncJobStatus.CANCELLING, SyncJobStatus.CANCELLED] - - -# --------------------------------------------------------------------------- -# provision_sync() tests -# --------------------------------------------------------------------------- - - -@dataclass -class ProvisionCase: - """Table-driven case for provision_sync().""" - - name: str - source_entry: Optional[SourceRegistryEntry] = None - schedule_config: Optional[ScheduleConfig] = None - run_immediately: bool = True - expected_none: bool = False - expected_error: Optional[str] = None - expected_status: Optional[int] = None - expected_cron: Optional[str] = None - expect_schedule_call: bool = False - - -PROVISION_CASES = [ - ProvisionCase( - name="federated_search_returns_none", - 
source_entry=_source_entry(federated_search=True), - expected_none=True, - ), - ProvisionCase( - name="no_schedule_no_immediate_returns_none", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron=None), - run_immediately=False, - expected_none=True, - ), - ProvisionCase( - name="default_continuous_schedule", - source_entry=_source_entry(supports_continuous=True), - expected_cron=CONTINUOUS_SOURCE_DEFAULT_CRON, - expect_schedule_call=True, - ), - ProvisionCase( - name="explicit_cron_used", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron="0 3 * * *"), - expected_cron="0 3 * * *", - expect_schedule_call=True, - ), - ProvisionCase( - name="sub_hourly_rejected_for_non_continuous", - source_entry=_source_entry(supports_continuous=False), - schedule_config=ScheduleConfig(cron="*/5 * * * *"), - expected_error="does not support continuous syncs", - expected_status=400, - ), - ProvisionCase( - name="every_minute_rejected_for_non_continuous", - source_entry=_source_entry(supports_continuous=False), - schedule_config=ScheduleConfig(cron="* * * * *"), - expected_error="does not support continuous syncs", - expected_status=400, - ), - ProvisionCase( - name="sub_hourly_ok_for_continuous", - source_entry=_source_entry(supports_continuous=True), - schedule_config=ScheduleConfig(cron="*/5 * * * *"), - expected_cron="*/5 * * * *", - expect_schedule_call=True, - ), - ProvisionCase( - name="happy_path_immediate_no_schedule", - source_entry=_source_entry(), - schedule_config=ScheduleConfig(cron=None), - run_immediately=True, - expected_none=False, - expected_cron=None, - expect_schedule_call=False, - ), -] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("case", PROVISION_CASES, ids=lambda c: c.name) -async def test_provision_sync(case: ProvisionCase): - """Test provision_sync() with table-driven cases.""" - sync_service = FakeSyncRecordService() - temporal_schedule_service = FakeTemporalScheduleService() - - mock_sync = _sync_schema() - 
mock_sync_job = _sync_job_schema() - sync_service.set_create_result(mock_sync, mock_sync_job) - - svc = _build_service( - sync_service=sync_service, - temporal_schedule_service=temporal_schedule_service, - ) - - uow = _FakeUoW() - - if case.expected_error: - with pytest.raises(HTTPException) as exc_info: - await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=case.source_entry or _source_entry(), - schedule_config=case.schedule_config, - run_immediately=case.run_immediately, - ctx=_ctx(), - uow=uow, - ) - assert exc_info.value.status_code == case.expected_status - assert case.expected_error in str(exc_info.value.detail) - return - - result = await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=case.source_entry or _source_entry(), - schedule_config=case.schedule_config, - run_immediately=case.run_immediately, - ctx=_ctx(), - uow=uow, - ) - - if case.expected_none: - assert result is None - assert len(sync_service._calls) == 0 - return - - assert result is not None - assert isinstance(result, SyncProvisionResult) - assert result.sync_id == SYNC_ID - assert result.cron_schedule == case.expected_cron - - create_calls = [c for c in sync_service._calls if c[0] == "create_sync"] - assert len(create_calls) == 1 - assert create_calls[0][3] == case.expected_cron # cron_schedule arg - - schedule_calls = [ - c for c in temporal_schedule_service._calls if c[0] == "create_or_update_schedule" - ] - if case.expect_schedule_call: - assert len(schedule_calls) == 1 - assert schedule_calls[0][2] == case.expected_cron - else: - assert len(schedule_calls) == 0 - - -@pytest.mark.asyncio -async def test_provision_sync_default_daily_schedule(): - """Default 
daily schedule uses current UTC hour:minute.""" - sync_service = FakeSyncRecordService() - temporal_schedule_service = FakeTemporalScheduleService() - - mock_sync = _sync_schema() - sync_service.set_create_result(mock_sync, _sync_job_schema()) - - svc = _build_service( - sync_service=sync_service, - temporal_schedule_service=temporal_schedule_service, - ) - - uow = _FakeUoW() - result = await svc.provision_sync( - uow.session, - name="Test", - source_connection_id=SC_ID, - destination_connection_ids=[DEST_CONN_ID], - collection_id=COLLECTION_ID, - collection_readable_id="test-collection", - source_entry=_source_entry(supports_continuous=False), - schedule_config=None, - run_immediately=True, - ctx=_ctx(), - uow=uow, - ) - - assert result is not None - parts = result.cron_schedule.split() - assert len(parts) == 5 - assert parts[2:] == ["*", "*", "*"] # daily schedule pattern - - schedule_calls = [ - c for c in temporal_schedule_service._calls if c[0] == "create_or_update_schedule" - ] - assert len(schedule_calls) == 1 diff --git a/backend/airweave/domains/syncs/tests/test_record_service.py b/backend/airweave/domains/syncs/tests/test_record_service.py deleted file mode 100644 index 83dfa4fdd..000000000 --- a/backend/airweave/domains/syncs/tests/test_record_service.py +++ /dev/null @@ -1,254 +0,0 @@ -"""Table-driven tests for SyncRecordService. - -Covers trigger_sync_run (happy, active-job, not-found). 
-""" - -from dataclasses import dataclass -from typing import Optional -from unittest.mock import AsyncMock, MagicMock, patch -from uuid import UUID, uuid4 - -import pytest -from fastapi import HTTPException - -from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID -from airweave.domains.syncs.record_service import SyncRecordService - -ORG_ID = uuid4() -SYNC_ID = uuid4() - - -def _mock_ctx() -> MagicMock: - ctx = MagicMock() - ctx.organization = MagicMock() - ctx.organization.id = ORG_ID - ctx.logger = MagicMock() - ctx.has_feature = MagicMock(return_value=False) - return ctx - - -def _mock_sync_model(sync_id: UUID = SYNC_ID, status: str = "active") -> MagicMock: - sync = MagicMock() - sync.id = sync_id - sync.name = "test-sync" - sync.status = status - return sync - - -def _mock_sync_job_model(sync_id: UUID = SYNC_ID, status: str = "PENDING") -> MagicMock: - job = MagicMock() - job.id = uuid4() - job.sync_id = sync_id - job.status = status - job.organization_id = ORG_ID - return job - - -# --------------------------------------------------------------------------- -# trigger_sync_run -# --------------------------------------------------------------------------- - - -@dataclass -class TriggerCase: - """Parameters for a single trigger_sync_run scenario.""" - - name: str - active_jobs: list - sync_exists: bool = True - sync_status: str = "active" - expect_error: Optional[type] = None - error_status: Optional[int] = None - - -TRIGGER_CASES = [ - TriggerCase( - name="happy_path", - active_jobs=[], - sync_exists=True, - ), - TriggerCase( - name="active_job_blocks", - active_jobs=[_mock_sync_job_model(status="running")], - expect_error=HTTPException, - error_status=400, - ), - TriggerCase( - name="sync_not_found", - active_jobs=[], - sync_exists=False, - expect_error=ValueError, - ), - TriggerCase( - name="non_active_sync_rejected", - active_jobs=[], - sync_exists=True, - sync_status="paused", - expect_error=HTTPException, - error_status=409, - ), -] - - 
-@pytest.mark.asyncio -@pytest.mark.parametrize("case", TRIGGER_CASES, ids=lambda c: c.name) -async def test_trigger_sync_run(case: TriggerCase) -> None: - """Verify trigger_sync_run behaviour for each scenario.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - - sync_job_repo.get_active_for_sync = AsyncMock(return_value=case.active_jobs) - sync_repo.get = AsyncMock( - return_value=_mock_sync_model(status=case.sync_status) if case.sync_exists else None - ) - - created_job = _mock_sync_job_model() - sync_job_repo.create = AsyncMock(return_value=created_job) - - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - db = AsyncMock() - ctx = _mock_ctx() - - if case.expect_error: - with pytest.raises(case.expect_error) as exc_info: - with patch("airweave.domains.syncs.record_service.UnitOfWork") as mock_uow_cls: - mock_uow = AsyncMock() - mock_uow.session = AsyncMock() - mock_uow.commit = AsyncMock() - mock_uow.session.refresh = AsyncMock() - mock_uow_cls.return_value.__aenter__ = AsyncMock(return_value=mock_uow) - mock_uow_cls.return_value.__aexit__ = AsyncMock(return_value=False) - await svc.trigger_sync_run(db, SYNC_ID, ctx) - if case.error_status and isinstance(exc_info.value, HTTPException): - assert exc_info.value.status_code == case.error_status - else: - with patch("airweave.domains.syncs.record_service.UnitOfWork") as mock_uow_cls: - mock_uow = AsyncMock() - mock_uow.session = AsyncMock() - mock_uow.commit = AsyncMock() - mock_uow.session.refresh = AsyncMock() - mock_uow_cls.return_value.__aenter__ = AsyncMock(return_value=mock_uow) - mock_uow_cls.return_value.__aexit__ = AsyncMock(return_value=False) - - with patch("airweave.domains.syncs.record_service.schemas") as mock_schemas: - mock_sync_schema = MagicMock() - mock_job_schema = MagicMock() - mock_schemas.Sync.model_validate.return_value = mock_sync_schema - 
mock_schemas.SyncJob.model_validate.return_value = mock_job_schema - mock_schemas.SyncJobCreate = MagicMock() - - result = await svc.trigger_sync_run(db, SYNC_ID, ctx) - assert result == (mock_sync_schema, mock_job_schema) - - sync_job_repo.create.assert_called_once() - mock_uow.commit.assert_called_once() - - -# --------------------------------------------------------------------------- -# resolve_destination_ids -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_create_sync_flushes_and_refreshes_job_before_validation() -> None: - """Verify create_sync flushes twice and refreshes the job before validation.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - - sync_schema = MagicMock(id=SYNC_ID) - sync_repo.create = AsyncMock(return_value=sync_schema) - created_job = MagicMock() - sync_job_repo.create = AsyncMock(return_value=created_job) - - uow = MagicMock() - uow.session = AsyncMock() - uow.session.flush = AsyncMock() - uow.session.refresh = AsyncMock() - ctx = _mock_ctx() - - with patch("airweave.domains.syncs.record_service.schemas") as mock_schemas: - validated_job_schema = MagicMock() - mock_schemas.SyncJob.model_validate.return_value = validated_job_schema - mock_schemas.SyncJobCreate = MagicMock() - - sync, sync_job = await svc.create_sync( - AsyncMock(), - name="Test Sync", - source_connection_id=uuid4(), - destination_connection_ids=[NATIVE_VESPA_UUID], - cron_schedule=None, - run_immediately=True, - ctx=ctx, - uow=uow, - ) - - assert sync is sync_schema - assert sync_job is validated_job_schema - assert uow.session.flush.await_count == 2 - uow.session.refresh.assert_awaited_once_with(created_job) - - -@pytest.mark.asyncio -async def test_create_sync_flushes_sync_even_without_immediate_job() -> None: - """Verify create_sync 
flushes once and skips job when run_immediately=False.""" - sync_repo = AsyncMock() - sync_job_repo = AsyncMock() - connection_repo = AsyncMock() - svc = SyncRecordService( - sync_repo=sync_repo, - sync_job_repo=sync_job_repo, - connection_repo=connection_repo, - ) - - sync_schema = MagicMock(id=SYNC_ID) - sync_repo.create = AsyncMock(return_value=sync_schema) - - uow = MagicMock() - uow.session = AsyncMock() - uow.session.flush = AsyncMock() - uow.session.refresh = AsyncMock() - ctx = _mock_ctx() - - sync, sync_job = await svc.create_sync( - AsyncMock(), - name="Test Sync", - source_connection_id=uuid4(), - destination_connection_ids=[NATIVE_VESPA_UUID], - cron_schedule="0 * * * *", - run_immediately=False, - ctx=ctx, - uow=uow, - ) - - assert sync is sync_schema - assert sync_job is None - uow.session.flush.assert_awaited_once() - uow.session.refresh.assert_not_awaited() - sync_job_repo.create.assert_not_called() - - -@pytest.mark.asyncio -async def test_resolve_destination_ids_returns_native_only() -> None: - """Verify resolve_destination_ids returns only native Vespa UUID.""" - svc = SyncRecordService( - sync_repo=AsyncMock(), - sync_job_repo=AsyncMock(), - connection_repo=AsyncMock(), - ) - db = AsyncMock() - ctx = _mock_ctx() - - destination_ids = await svc.resolve_destination_ids(db, ctx) - - assert destination_ids == [NATIVE_VESPA_UUID] diff --git a/backend/airweave/domains/syncs/tests/test_service.py b/backend/airweave/domains/syncs/tests/test_service.py index 601aa62ef..9632f49cf 100644 --- a/backend/airweave/domains/syncs/tests/test_service.py +++ b/backend/airweave/domains/syncs/tests/test_service.py @@ -10,6 +10,7 @@ from uuid import uuid4 import pytest +from fastapi import HTTPException from airweave.core.shared_models import SyncJobStatus from airweave.domains.syncs.service import SyncService @@ -93,9 +94,14 @@ async def test_run(case: RunCase): ) svc = SyncService( - state_machine=fake_state_machine, + sync_repo=MagicMock(), + 
sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=MagicMock(), + job_state_machine=fake_state_machine, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) sync = _mock_sync() @@ -160,9 +166,14 @@ async def test_run_forwards_optional_kwargs(): ) svc = SyncService( - state_machine=fake_state_machine, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=MagicMock(), + job_state_machine=fake_state_machine, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) mock_db = AsyncMock() @@ -205,9 +216,7 @@ async def test_credential_error_propagates_error_category(): cause = TokenExpiredError( "JWT expired", source_short_name="github", provider_kind=AuthProviderKind.OAUTH ) - wrapper = SourceValidationError( - short_name="github", reason="credential validation failed" - ) + wrapper = SourceValidationError(short_name="github", reason="credential validation failed") wrapper.__cause__ = cause fake_sm = AsyncMock() @@ -215,9 +224,14 @@ async def test_credential_error_propagates_error_category(): fake_factory.create_orchestrator = AsyncMock(side_effect=wrapper) svc = SyncService( - state_machine=fake_sm, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=AsyncMock(), + job_state_machine=fake_sm, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) with patch("airweave.domains.syncs.service.get_db_context") as mock_db_ctx: @@ -236,10 +250,7 @@ async def test_credential_error_propagates_error_category(): fake_sm.transition.assert_awaited_once() call_kwargs = fake_sm.transition.call_args.kwargs assert call_kwargs["target"] == SyncJobStatus.FAILED - assert ( - call_kwargs["error_category"] - 
== SourceConnectionErrorCategory.OAUTH_CREDENTIALS_EXPIRED - ) + assert call_kwargs["error_category"] == SourceConnectionErrorCategory.OAUTH_CREDENTIALS_EXPIRED @pytest.mark.asyncio @@ -247,14 +258,17 @@ async def test_non_credential_error_has_no_error_category(): """Non-auth factory error -> error_category=None on state machine transition.""" fake_sm = AsyncMock() fake_factory = MagicMock() - fake_factory.create_orchestrator = AsyncMock( - side_effect=RuntimeError("bad config") - ) + fake_factory.create_orchestrator = AsyncMock(side_effect=RuntimeError("bad config")) svc = SyncService( - state_machine=fake_sm, + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), + state_machine=AsyncMock(), + job_state_machine=fake_sm, + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) with patch("airweave.domains.syncs.service.get_db_context") as mock_db_ctx: @@ -278,9 +292,591 @@ def test_stores_injected_deps(): fake_sm = MagicMock() fake_factory = MagicMock() svc = SyncService( + sync_repo=MagicMock(), + sync_job_repo=MagicMock(), + sync_cursor_repo=MagicMock(), state_machine=fake_sm, + job_state_machine=MagicMock(), + temporal_workflow_service=MagicMock(), + temporal_schedule_service=MagicMock(), sync_factory=fake_factory, - sync_state_machine=MagicMock(), ) assert svc._state_machine is fake_sm assert svc._sync_factory is fake_factory + + +# --------------------------------------------------------------------------- +# Helper: build a SyncService with configurable mocks +# --------------------------------------------------------------------------- + + +def _build_svc( + sync_repo=None, + sync_job_repo=None, + sync_cursor_repo=None, + state_machine=None, + job_state_machine=None, + temporal_workflow_service=None, + temporal_schedule_service=None, + sync_factory=None, +): + return SyncService( + sync_repo=sync_repo or AsyncMock(), + 
sync_job_repo=sync_job_repo or AsyncMock(), + sync_cursor_repo=sync_cursor_repo or AsyncMock(), + state_machine=state_machine or AsyncMock(), + job_state_machine=job_state_machine or AsyncMock(), + temporal_workflow_service=temporal_workflow_service or AsyncMock(), + temporal_schedule_service=temporal_schedule_service or AsyncMock(), + sync_factory=sync_factory or MagicMock(), + ) + + +# --------------------------------------------------------------------------- +# get() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_returns_sync(): + repo = AsyncMock() + expected = MagicMock() + repo.get.return_value = expected + + svc = _build_svc(sync_repo=repo) + result = await svc.get(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert result is expected + + +@pytest.mark.asyncio +async def test_get_raises_when_not_found(): + repo = AsyncMock() + repo.get.return_value = None + + svc = _build_svc(sync_repo=repo) + with pytest.raises(HTTPException) as exc_info: + await svc.get(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# pause() / resume() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_pause_delegates_to_state_machine(): + from airweave.core.shared_models import SyncStatus + + sm = AsyncMock() + expected = MagicMock() + sm.transition.return_value = expected + + svc = _build_svc(state_machine=sm) + sid = uuid4() + result = await svc.pause(sid, _mock_ctx(), reason="maintenance") + + assert result is expected + call_kw = sm.transition.call_args.kwargs + assert call_kw["sync_id"] == sid + assert call_kw["target"] == SyncStatus.PAUSED + assert call_kw["reason"] == "maintenance" + + +@pytest.mark.asyncio +async def test_resume_delegates_to_state_machine(): + from airweave.core.shared_models import 
SyncStatus + + sm = AsyncMock() + expected = MagicMock() + sm.transition.return_value = expected + + svc = _build_svc(state_machine=sm) + sid = uuid4() + result = await svc.resume(sid, _mock_ctx()) + + assert result is expected + assert sm.transition.call_args.kwargs["target"] == SyncStatus.ACTIVE + + +# --------------------------------------------------------------------------- +# resolve_destination_ids() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_resolve_destination_ids_returns_vespa(): + from airweave.core.constants.reserved_ids import NATIVE_VESPA_UUID + + svc = _build_svc() + result = await svc.resolve_destination_ids(AsyncMock(), _mock_ctx()) + assert result == [NATIVE_VESPA_UUID] + + +# --------------------------------------------------------------------------- +# get_jobs() +# --------------------------------------------------------------------------- + + +def _orm_sync_job(job_id=None, sync_id=None, status=SyncJobStatus.COMPLETED): + """Create a MagicMock that passes schemas.SyncJob.model_validate().""" + m = MagicMock() + m.id = job_id or uuid4() + m.sync_id = sync_id or uuid4() + m.organization_id = uuid4() + m.status = status + m.scheduled = False + m.entities_inserted = 0 + m.entities_updated = 0 + m.entities_deleted = 0 + m.entities_kept = 0 + m.entities_skipped = 0 + m.entities_encountered = {} + m.started_at = None + m.completed_at = None + m.failed_at = None + m.error = None + m.error_category = None + m.access_token = None + m.sync_config = None + m.sync_metadata = None + m.created_by_email = None + m.modified_by_email = None + m.created_at = None + m.modified_at = None + m.sync_name = None + return m + + +@pytest.mark.asyncio +async def test_get_jobs_returns_validated_schemas(): + job_repo = AsyncMock() + mock_job = _orm_sync_job() + job_repo.get_all_by_sync_id.return_value = [mock_job] + + svc = _build_svc(sync_job_repo=job_repo) + jobs = await 
svc.get_jobs(AsyncMock(), sync_id=uuid4(), ctx=_mock_ctx()) + assert len(jobs) == 1 + assert jobs[0].id == mock_job.id + + +# --------------------------------------------------------------------------- +# validate_force_full_sync() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validate_force_full_sync_no_cursor(): + cursor_repo = AsyncMock() + cursor_repo.get_by_sync_id.return_value = None + + svc = _build_svc(sync_cursor_repo=cursor_repo) + ctx = _mock_ctx() + await svc.validate_force_full_sync(AsyncMock(), uuid4(), ctx) + ctx.logger.info.assert_called_once() + assert "no cursor data" in ctx.logger.info.call_args[0][0] + + +@pytest.mark.asyncio +async def test_validate_force_full_sync_with_cursor(): + cursor_repo = AsyncMock() + cursor = MagicMock() + cursor.cursor_data = {"some": "data"} + cursor_repo.get_by_sync_id.return_value = cursor + + svc = _build_svc(sync_cursor_repo=cursor_repo) + ctx = _mock_ctx() + await svc.validate_force_full_sync(AsyncMock(), uuid4(), ctx) + ctx.logger.info.assert_called_once() + assert "Force full sync" in ctx.logger.info.call_args[0][0] + + +# --------------------------------------------------------------------------- +# cancel_job() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_job_not_found(): + job_repo = AsyncMock() + job_repo.get.return_value = None + + svc = _build_svc(sync_job_repo=job_repo) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), job_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_cancel_job_wrong_status(): + job_repo = AsyncMock() + job = MagicMock() + job.status = SyncJobStatus.COMPLETED + job_repo.get.return_value = job + + svc = _build_svc(sync_job_repo=job_repo) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), 
job_id=uuid4(), ctx=_mock_ctx()) + assert exc_info.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_cancel_job_success(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.RUNNING) + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": True, + "workflow_found": True, + } + + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + result = await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + job_sm.transition.assert_awaited_once() + assert job_sm.transition.call_args.kwargs["target"] == SyncJobStatus.CANCELLING + temporal.cancel_sync_job_workflow.assert_awaited_once() + assert result is not None + + +@pytest.mark.asyncio +async def test_cancel_pending_job_transitions_directly_to_cancelled(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.PENDING) + job_repo.get.return_value = job + + temporal = AsyncMock() + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + job_sm.transition.assert_awaited_once() + assert job_sm.transition.call_args.kwargs["target"] == SyncJobStatus.CANCELLED + temporal.cancel_sync_job_workflow.assert_awaited_once() + db.refresh.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cancel_job_workflow_not_found_marks_cancelled(): + job_id = uuid4() + job_repo = AsyncMock() + job = _orm_sync_job(job_id=job_id, status=SyncJobStatus.RUNNING) + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": True, + "workflow_found": False, + } + + job_sm = AsyncMock() + db = AsyncMock() + + svc = _build_svc( + 
sync_job_repo=job_repo, + temporal_workflow_service=temporal, + job_state_machine=job_sm, + ) + await svc.cancel_job(db, job_id=job_id, ctx=_mock_ctx()) + + assert job_sm.transition.await_count == 2 + first_call = job_sm.transition.call_args_list[0].kwargs + assert first_call["target"] == SyncJobStatus.CANCELLING + second_call = job_sm.transition.call_args_list[1].kwargs + assert second_call["target"] == SyncJobStatus.CANCELLED + + +@pytest.mark.asyncio +async def test_cancel_job_temporal_failure(): + job_id = uuid4() + job_repo = AsyncMock() + job = MagicMock() + job.id = job_id + job.status = SyncJobStatus.RUNNING + job_repo.get.return_value = job + + temporal = AsyncMock() + temporal.cancel_sync_job_workflow.return_value = { + "success": False, + "workflow_found": True, + } + + svc = _build_svc(sync_job_repo=job_repo, temporal_workflow_service=temporal) + with pytest.raises(HTTPException) as exc_info: + await svc.cancel_job(AsyncMock(), job_id=job_id, ctx=_mock_ctx()) + assert exc_info.value.status_code == 502 + + +# --------------------------------------------------------------------------- +# _resolve_cron() / _validate_cron_for_source() +# --------------------------------------------------------------------------- + + +def _mock_source_entry(*, short_name="github", continuous=False, federated=False): + entry = MagicMock() + entry.short_name = short_name + entry.supports_continuous = continuous + entry.federated_search = federated + return entry + + +def test_resolve_cron_explicit(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + result = svc._resolve_cron( + ScheduleConfig(cron="0 6 * * *"), + _mock_source_entry(), + _mock_ctx(), + ) + assert result == "0 6 * * *" + + +def test_resolve_cron_explicit_null(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + result = svc._resolve_cron( + ScheduleConfig(cron=None), + _mock_source_entry(), + _mock_ctx(), + ) + assert result is None + 
+ +def test_resolve_cron_continuous_default(): + from airweave.domains.syncs.types import CONTINUOUS_SOURCE_DEFAULT_CRON + + svc = _build_svc() + result = svc._resolve_cron(None, _mock_source_entry(continuous=True), _mock_ctx()) + assert result == CONTINUOUS_SOURCE_DEFAULT_CRON + + +def test_resolve_cron_daily_default(): + svc = _build_svc() + result = svc._resolve_cron(None, _mock_source_entry(), _mock_ctx()) + assert result is not None + parts = result.split() + assert len(parts) == 5 + assert parts[2:] == ["*", "*", "*"] + + +def test_validate_cron_allows_continuous(): + svc = _build_svc() + svc._validate_cron_for_source("* * * * *", _mock_source_entry(continuous=True)) + + +def test_validate_cron_rejects_every_minute(): + svc = _build_svc() + with pytest.raises(HTTPException) as exc_info: + svc._validate_cron_for_source("* * * * *", _mock_source_entry()) + assert exc_info.value.status_code == 400 + + +def test_validate_cron_rejects_sub_hourly(): + svc = _build_svc() + with pytest.raises(HTTPException) as exc_info: + svc._validate_cron_for_source("*/5 * * * *", _mock_source_entry()) + assert exc_info.value.status_code == 400 + + +# --------------------------------------------------------------------------- +# create() — federated / no-schedule / happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_skips_federated(): + svc = _build_svc() + with pytest.raises(ValueError, match="federated"): + await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(federated=True), + schedule_config=None, + run_immediately=False, + ctx=_mock_ctx(), + uow=MagicMock(), + ) + + +@pytest.mark.asyncio +async def test_create_no_cron_no_run_immediately(): + from airweave.schemas.source_connection import ScheduleConfig + + svc = _build_svc() + with 
pytest.raises(ValueError, match="no schedule"): + await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(), + schedule_config=ScheduleConfig(cron=None), + run_immediately=False, + ctx=_mock_ctx(), + uow=MagicMock(), + ) + + +@pytest.mark.asyncio +async def test_create_with_cron_calls_temporal_schedule(): + from airweave.schemas.source_connection import ScheduleConfig + + sync_repo = AsyncMock() + mock_sync = MagicMock() + mock_sync.id = uuid4() + sync_repo.create.return_value = mock_sync + + job_repo = AsyncMock() + temporal_sched = AsyncMock() + + uow = MagicMock() + uow.session = AsyncMock() + uow.commit = AsyncMock() + + svc = _build_svc( + sync_repo=sync_repo, + sync_job_repo=job_repo, + temporal_schedule_service=temporal_sched, + ) + + result = await svc.create( + AsyncMock(), + name="test", + source_connection_id=uuid4(), + destination_connection_ids=[uuid4()], + collection_id=uuid4(), + collection_readable_id="col-x", + source_entry=_mock_source_entry(), + schedule_config=ScheduleConfig(cron="0 6 * * *"), + run_immediately=False, + ctx=_mock_ctx(), + uow=uow, + ) + assert result is not None + assert result.sync_id == mock_sync.id + temporal_sched.create_or_update_schedule.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# delete() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_delete_delegates(): + svc = _build_svc() + svc._cancel_active_sync = AsyncMock(return_value=True) + svc._wait_for_terminal = AsyncMock() + svc._schedule_cleanup = AsyncMock() + + await svc.delete( + AsyncMock(), + sync_id=uuid4(), + collection_id=uuid4(), + organization_id=uuid4(), + ctx=_mock_ctx(), + ) + svc._cancel_active_sync.assert_awaited_once() + svc._wait_for_terminal.assert_awaited_once() + 
svc._schedule_cleanup.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# _cancel_active_sync() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_active_sync_cancels_running_job(): + sync_id = uuid4() + job = MagicMock() + job.id = uuid4() + job.status = SyncJobStatus.RUNNING + + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = job + + temporal = AsyncMock() + svc = _build_svc(sync_job_repo=job_repo, temporal_workflow_service=temporal) + + result = await svc._cancel_active_sync(AsyncMock(), sync_id, _mock_ctx()) + assert result is True + temporal.cancel_sync_job_workflow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_cancel_active_sync_skips_terminal(): + job = MagicMock() + job.status = SyncJobStatus.COMPLETED + + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = job + + svc = _build_svc(sync_job_repo=job_repo) + result = await svc._cancel_active_sync(AsyncMock(), uuid4(), _mock_ctx()) + assert result is False + + +# --------------------------------------------------------------------------- +# _wait_for_terminal() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_wait_for_terminal_returns_when_no_job(): + job_repo = AsyncMock() + job_repo.get_latest_by_sync_id.return_value = None + svc = _build_svc(sync_job_repo=job_repo) + db = MagicMock() + with patch("airweave.domains.syncs.service.asyncio.sleep", new_callable=AsyncMock): + await svc._wait_for_terminal(db, uuid4(), 5, _mock_ctx()) + job_repo.get_latest_by_sync_id.assert_awaited() + + +# --------------------------------------------------------------------------- +# _schedule_cleanup() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_schedule_cleanup_calls_temporal(): + temporal = 
AsyncMock() + svc = _build_svc(temporal_workflow_service=temporal) + await svc._schedule_cleanup(uuid4(), uuid4(), uuid4(), _mock_ctx()) + temporal.start_cleanup_sync_data_workflow.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_schedule_cleanup_handles_error(): + temporal = AsyncMock() + temporal.start_cleanup_sync_data_workflow.side_effect = RuntimeError("boom") + + svc = _build_svc(temporal_workflow_service=temporal) + ctx = _mock_ctx() + await svc._schedule_cleanup(uuid4(), uuid4(), uuid4(), ctx) + ctx.logger.error.assert_called_once() diff --git a/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py b/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py index 706d43ce1..7d0ee6a9f 100644 --- a/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py +++ b/backend/airweave/domains/temporal/activities/tests/test_transition_sync_job.py @@ -1,6 +1,6 @@ """Tests for TransitionSyncJobActivity.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock from uuid import UUID import pytest @@ -8,7 +8,6 @@ from airweave.core.shared_models import SyncJobStatus from airweave.domains.temporal.activities.transition_sync_job import ( TransitionSyncJobActivity, - _STATUS_MAP, ) from .conftest import ORG_ID, SYNC_ID, SYNC_JOB_ID, make_ctx_dict @@ -100,6 +99,27 @@ async def test_failed_transition_with_error(activity, state_machine): assert call["error"] == "Something went wrong" +@pytest.mark.unit +async def test_cancelling_transition(activity, state_machine): + lifecycle = { + "organization_id": ORG_ID, + "sync_id": SYNC_ID, + "sync_job_id": SYNC_JOB_ID, + "collection_id": "00000000-0000-0000-0000-000000000030", + "source_connection_id": "00000000-0000-0000-0000-000000000050", + } + + await activity.run( + transition="cancelling", + sync_job_id=SYNC_JOB_ID, + ctx_dict=make_ctx_dict(), + lifecycle_data=lifecycle, + ) + + call = state_machine.calls[0] + assert 
call["target"] == SyncJobStatus.CANCELLING + + @pytest.mark.unit async def test_cancelled_transition(activity, state_machine): lifecycle = { diff --git a/backend/airweave/domains/temporal/activities/transition_sync_job.py b/backend/airweave/domains/temporal/activities/transition_sync_job.py index cf8241259..1dad9f77e 100644 --- a/backend/airweave/domains/temporal/activities/transition_sync_job.py +++ b/backend/airweave/domains/temporal/activities/transition_sync_job.py @@ -1,6 +1,6 @@ """Transition sync job activity — thin Temporal wrapper over SyncJobStateMachine. -Called by the workflow for COMPLETED, FAILED, and CANCELLED transitions. +Called by the workflow for CANCELLING, COMPLETED, FAILED, and CANCELLED transitions. Deserializes Temporal payloads and delegates to the state machine. """ @@ -19,6 +19,7 @@ from airweave.domains.temporal.activities.context import build_activity_context _STATUS_MAP: dict[str, SyncJobStatus] = { + "cancelling": SyncJobStatus.CANCELLING, "completed": SyncJobStatus.COMPLETED, "failed": SyncJobStatus.FAILED, "cancelled": SyncJobStatus.CANCELLED, @@ -46,10 +47,10 @@ async def run( stats_dict: Optional[Dict[str, Any]] = None, timestamp_iso: Optional[str] = None, ) -> None: - """Execute a terminal state transition via the state machine. + """Execute a state transition via the state machine. Args: - transition: One of "completed", "failed", "cancelled". + transition: One of "cancelling", "completed", "failed", "cancelled". sync_job_id: The sync job UUID as a string. ctx_dict: Serialized context dict (contains organization). lifecycle_data: Fields for building LifecycleData. 
diff --git a/backend/airweave/domains/temporal/workflows/run_source_connection.py b/backend/airweave/domains/temporal/workflows/run_source_connection.py index 1d33ae422..d79a50412 100644 --- a/backend/airweave/domains/temporal/workflows/run_source_connection.py +++ b/backend/airweave/domains/temporal/workflows/run_source_connection.py @@ -1,6 +1,6 @@ """Run source connection workflow — the sync state machine. -Owns terminal state transitions (COMPLETED, FAILED, CANCELLED) via +Owns state transitions (CANCELLING, COMPLETED, FAILED, CANCELLED) via TransitionSyncJobActivity. RUNNING is published by the orchestrator because only it knows when sync work actually begins. """ @@ -95,7 +95,9 @@ async def run( ) except BaseException as e: if is_cancelled_exception(e): - await self._transition("cancelled", sync_job_dict, ctx_dict, lifecycle, shield=True) + cancel_args = (sync_job_dict, ctx_dict, lifecycle) + await self._transition("cancelling", *cancel_args, shield=True) + await self._transition("cancelled", *cancel_args, shield=True) raise if self._is_orphaned_sync_error(e): reason = self._extract_orphaned_reason(e) @@ -203,7 +205,12 @@ async def _transition( error: Optional[str] = None, shield: bool = False, ) -> None: - """Call TransitionSyncJobActivity for a terminal state change.""" + """Call TransitionSyncJobActivity for a state change. + + Shielded transitions (cancel path) are best-effort retries — the + orchestrator already performed these transitions, so failures here + are expected and logged at debug level. 
+ """ timestamp = workflow.now().replace(tzinfo=None).isoformat() coro = workflow.execute_activity( transition_sync_job_activity, @@ -227,9 +234,8 @@ async def _transition( try: await (asyncio.shield(coro) if shield else coro) except Exception: - workflow.logger.warning( - f"Failed to transition sync job {sync_job_dict.get('id')} to {transition}" - ) + log = workflow.logger.debug if shield else workflow.logger.warning + log(f"Failed to transition sync job {sync_job_dict.get('id')} to {transition}") # ------------------------------------------------------------------ # Self-destruct orphaned sync diff --git a/backend/airweave/platform/configs/auth.py b/backend/airweave/platform/configs/auth.py index 8659d872e..6669ea1ee 100644 --- a/backend/airweave/platform/configs/auth.py +++ b/backend/airweave/platform/configs/auth.py @@ -752,6 +752,52 @@ class ShopifyAuthConfig(AuthConfig): ) +class SharePointOnlineAppAuthConfig(AuthConfig): + """SharePoint Online app-only authentication using client credentials. + + Uses client_id + client_secret for Microsoft Graph API calls, + and a certificate private key for SharePoint REST API calls. + Requires an Azure AD app registration with application permissions + and admin consent. + """ + + tenant_id: str = Field( + title="Tenant ID", + description="Azure AD tenant ID (e.g., 'contoso.onmicrosoft.com' or a UUID)", + min_length=1, + ) + client_id: str = Field( + title="Client ID", + description="Application (client) ID from the Azure AD app registration", + min_length=1, + ) + client_secret: str = Field( + title="Client Secret", + description="Client secret from the Azure AD app registration (for Graph API)", + min_length=1, + json_schema_extra={"is_secret": True}, + ) + private_key: str = Field( + title="Private Key (PEM)", + description=( + "PEM-encoded private key for certificate authentication " + "(for SharePoint REST API). 
Starts with '-----BEGIN PRIVATE KEY-----'" + ), + min_length=1, + json_schema_extra={"is_secret": True}, + ) + certificate: str = Field( + default="", + title="Certificate (PEM)", + description=( + "PEM-encoded certificate that was uploaded to the Azure AD app registration. " + "Used to compute the x5t thumbprint for SP REST API token exchange. " + "If omitted, SP site group expansion will not work." + ), + json_schema_extra={"is_secret": True}, + ) + + class ServiceNowAuthConfig(AuthConfig): """ServiceNow instance authentication credentials schema. @@ -974,6 +1020,62 @@ class PipedreamAuthConfig(AuthConfig): ) +class CustomAuthConfig(AuthConfig): + """Custom Auth Provider authentication credentials schema. + + Stores a base URL and API key for a customer-hosted token endpoint. + Airweave calls GET {base_url}/{source_connection_id} to fetch credentials + for each source connection. The customer is responsible for returning fresh tokens. + + Contract: + - GET {base_url} must return 2xx (used for validation) + - GET {base_url}/{source_connection_id} must return JSON with credentials + - source_connection_id is the UUID of the source connection in Airweave + - All requests include an X-API-Key header for authentication + - No refresh_token needed — Airweave re-fetches automatically + + Response format (only two shapes): + - Token-based sources (OAuth, PATs): {"access_token": "..."} + - API key sources (Stripe, etc.): {"api_key": "..."} + """ + + base_endpoint_url: str = Field( + title="Endpoint Base URL", + description=( + "Base URL of your token endpoint. " + "Airweave calls GET {base_url}/{source_connection_id} for each " + "source connection and expects a JSON response with the credential " + 'the source needs (e.g. {"access_token": "..."} for OAuth sources). ' + "No refresh_token needed — Airweave re-fetches automatically. " + "GET {base_url} must return 2xx for validation." 
+ ), + ) + api_key: str = Field( + title="API Key", + description=( + "API key sent as X-API-Key header to authenticate all requests to your endpoint." + ), + ) + + @field_validator("base_endpoint_url") + @classmethod + def validate_base_endpoint_url(cls, v: str) -> str: + """Validate the endpoint URL for SSRF safety.""" + if not v or not v.strip(): + raise ValueError("base_endpoint_url is required") + v = v.strip().rstrip("/") + validate_url(v) + return v + + @field_validator("api_key") + @classmethod + def validate_api_key(cls, v: str) -> str: + """Validate that api_key is non-empty.""" + if not v or not v.strip(): + raise ValueError("api_key is required") + return v.strip() + + class ZohoCRMAuthConfig(OAuth2WithRefreshAuthConfig): """Zoho CRM authentication credentials schema.""" diff --git a/backend/airweave/platform/configs/config.py b/backend/airweave/platform/configs/config.py index 70ca32ac5..db091cf05 100644 --- a/backend/airweave/platform/configs/config.py +++ b/backend/airweave/platform/configs/config.py @@ -1130,6 +1130,12 @@ class ComposioConfig(AuthProviderConfig): ) +class CustomConfig(AuthProviderConfig): + """Custom Auth Provider configuration schema.""" + + pass + + class PipedreamConfig(AuthProviderConfig): """Pipedream Auth Provider configuration schema.""" @@ -1194,24 +1200,20 @@ class SharePointOnlineConfig(SourceConfig): default="", title="SharePoint Site URL", description=( - "URL of the SharePoint site(s) to sync. Supports a single URL " - "(e.g., 'https://contoso.sharepoint.com/sites/Marketing'), " - "comma-separated URLs for multiple sites, or leave empty to " - "sync all accessible sites." + "URL of a specific SharePoint site to sync " + "(e.g., 'https://contoso.sharepoint.com/sites/Marketing'). " + "Leave empty to sync all sites in the tenant." 
), ) @field_validator("site_url") @classmethod def validate_site_url_ssrf(cls, v: str) -> str: - """Validate each comma-separated site URL for SSRF safety.""" + """Validate site URL for SSRF safety.""" if not v: return v - for url in v.split(","): - url = url.strip() - if url: - validate_url(url) - return v + validate_url(v.strip()) + return v.strip() include_personal_sites: bool = Field( default=False, diff --git a/backend/airweave/platform/decorators.py b/backend/airweave/platform/decorators.py index 7c2b4b9dd..25fbdec1b 100644 --- a/backend/airweave/platform/decorators.py +++ b/backend/airweave/platform/decorators.py @@ -198,6 +198,7 @@ def auth_provider( short_name: str, auth_config_class: Type[BaseModel], config_class: Type[BaseModel], + feature_flag: Optional[str] = None, ) -> Callable[[type[_AuthProviderT]], type[_AuthProviderT]]: """Class decorator to mark a class as representing an Airweave auth provider. @@ -207,6 +208,7 @@ def auth_provider( short_name (str): The short name of the auth provider. auth_config_class (Type[BaseModel]): The authentication config class. config_class (Type[BaseModel]): The configuration class. + feature_flag (Optional[str]): Optional feature flag required to access this provider. 
Returns: ------- @@ -220,6 +222,7 @@ def decorator(cls: type[_AuthProviderT]) -> type[_AuthProviderT]: cls.short_name = short_name cls.auth_config_class = auth_config_class cls.config_class = config_class + cls.feature_flag = feature_flag return cls return decorator diff --git a/backend/airweave/platform/destinations/vespa/destination.py b/backend/airweave/platform/destinations/vespa/destination.py index 7e006c665..b92cd14d7 100644 --- a/backend/airweave/platform/destinations/vespa/destination.py +++ b/backend/airweave/platform/destinations/vespa/destination.py @@ -112,10 +112,12 @@ async def create( instance.organization_id = organization_id # Initialize components + source_supports_acl = kwargs.get("source_supports_acl", False) instance._client = await VespaClient.connect(logger=instance.logger) instance._transformer = EntityTransformer( collection_id=collection_id, logger=instance.logger, + source_supports_acl=source_supports_acl, ) instance._query_builder = QueryBuilder() diff --git a/backend/airweave/platform/destinations/vespa/transformer.py b/backend/airweave/platform/destinations/vespa/transformer.py index 378ce214e..d118f6d62 100644 --- a/backend/airweave/platform/destinations/vespa/transformer.py +++ b/backend/airweave/platform/destinations/vespa/transformer.py @@ -25,7 +25,7 @@ def _sanitize_for_vespa(text: str) -> str: - """Sanitize text for Vespa by removing illegal characters. + r"""Sanitize text for Vespa by removing illegal characters. Vespa strictly rejects: 1. Control characters (code points < 32) except \n (0x0A), \r (0x0D), \t (0x09) @@ -167,15 +167,20 @@ def __init__( self, collection_id: Optional[UUID] = None, logger: Optional[ContextualLogger] = None, + source_supports_acl: bool = False, ): """Initialize the entity transformer. Args: collection_id: SQL collection UUID for multi-tenant filtering logger: Optional logger for debug/warning messages + source_supports_acl: Whether the source supports access control. 
+ When True, entities with no access data default to invisible (fail-closed). + When False, entities default to public (visible to all). """ self.collection_id = collection_id self._logger = logger or default_logger + self._source_supports_acl = source_supports_acl def transform(self, entity: BaseEntity) -> VespaDocument: """Transform a single entity to Vespa document format. @@ -343,13 +348,17 @@ def _add_access_control_fields(self, fields: Dict[str, Any], entity: BaseEntity) """Add access control fields. Always sets access control fields with appropriate defaults: - - AC-enabled sources: Use the actual ACL values from entity.access - - Non-AC sources: Set is_public=True so entities are visible to everyone + - If entity has access data: use the actual ACL values + - AC-enabled source but no access data: fail-closed (invisible) + - Non-AC source: default to public (visible to everyone) """ access = getattr(entity, "access", None) if access is not None: fields["access_is_public"] = access.is_public fields["access_viewers"] = access.viewers if access.viewers else [] + elif self._source_supports_acl: + fields["access_is_public"] = False + fields["access_viewers"] = [] else: fields["access_is_public"] = True fields["access_viewers"] = [] diff --git a/backend/airweave/platform/sources/__init__.py b/backend/airweave/platform/sources/__init__.py index 84b97bd74..0dc48be82 100644 --- a/backend/airweave/platform/sources/__init__.py +++ b/backend/airweave/platform/sources/__init__.py @@ -50,7 +50,7 @@ from .servicenow import ServiceNowSource from .sharepoint import SharePointSource from .sharepoint2019v2.source import SharePoint2019V2Source -from .sharepoint_online.source import SharePointOnlineSource +from .sharepoint_online.source import SharePointOnlineAppSource, SharePointOnlineSource from .shopify import ShopifySource from .slab import SlabSource from .slack import SlackSource @@ -117,6 +117,7 @@ SharePointSource, SharePoint2019V2Source, SharePointOnlineSource, + 
SharePointOnlineAppSource, ShopifySource, SlabSource, SliteSource, diff --git a/backend/airweave/platform/sources/sharepoint_online/__init__.py b/backend/airweave/platform/sources/sharepoint_online/__init__.py index c6b62d75e..a8701157f 100644 --- a/backend/airweave/platform/sources/sharepoint_online/__init__.py +++ b/backend/airweave/platform/sources/sharepoint_online/__init__.py @@ -1,8 +1,12 @@ """SharePoint Online source connector. Uses Microsoft Graph API for content sync and Entra ID for access control. +Two variants: OAuth (delegated) and App (client credentials). """ -from airweave.platform.sources.sharepoint_online.source import SharePointOnlineSource +from airweave.platform.sources.sharepoint_online.source import ( + SharePointOnlineAppSource, + SharePointOnlineSource, +) -__all__ = ["SharePointOnlineSource"] +__all__ = ["SharePointOnlineSource", "SharePointOnlineAppSource"] diff --git a/backend/airweave/platform/sources/sharepoint_online/acl.py b/backend/airweave/platform/sources/sharepoint_online/acl.py index 404cc0d6a..b770540fe 100644 --- a/backend/airweave/platform/sources/sharepoint_online/acl.py +++ b/backend/airweave/platform/sources/sharepoint_online/acl.py @@ -6,7 +6,13 @@ - grantedToV2.user → user:{email} - grantedToV2.group (Entra ID) → group:entra:{group_id} - grantedToV2.siteGroup → group:sp:{site_group_name} -- link with scope "organization" → is_public (org-wide access) +- link with scope "anonymous" → is_public (true tenant-wide / internet-wide access) +- link with scope "organization" or "users" → group:sp:sharinglinks.{itemId}.{scopeRole}.{linkId} + derived from the permission and the file's SP UniqueId. Microsoft represents + these links as a ``link`` permission rather than a ``siteGroup`` grant, but + internally tracks redeemers as members of a ``SharingLinks...`` + SP site group. We translate so that membership intersection at search time + works for users who have actually redeemed the link. 
""" from typing import Any, Dict, List, Optional @@ -76,20 +82,66 @@ def has_read_permission(permission: Dict[str, Any]) -> bool: return any(r in ("read", "write", "owner", "sp.full control") for r in roles) -def is_org_wide_link(permission: Dict[str, Any]) -> bool: - """Check if a permission is an organization-wide sharing link.""" +def is_anonymous_link(permission: Dict[str, Any]) -> bool: + """Check if a permission is an anonymous sharing link.""" link = permission.get("link") if not link: return False - return link.get("scope", "") == "organization" + return link.get("scope", "") == "anonymous" -def is_anonymous_link(permission: Dict[str, Any]) -> bool: - """Check if a permission is an anonymous sharing link.""" +# Mapping from Graph (link.scope, link.type) to SharePoint's SharingLinks group +# suffix. Verified empirically against neenacorp.sharepoint.com: +# organization+edit → OrganizationEdit +# organization+view → OrganizationView +# users+edit / users+view → Flexible (both collapse; SP stores role separately) +# Anonymous is handled by ``is_public`` and does not need a derived group. +_SCOPE_ROLE_MAP: Dict[tuple, str] = { + ("organization", "edit"): "OrganizationEdit", + ("organization", "view"): "OrganizationView", + ("users", "edit"): "Flexible", + ("users", "view"): "Flexible", +} + + +def link_permission_to_sp_group_viewer( + permission: Dict[str, Any], sp_unique_id: Optional[str] +) -> Optional[str]: + """Derive the SharingLinks SP site group viewer for a non-anonymous link permission. + + SharePoint creates an internal site group named + ``SharingLinks...`` for each sharing + link, whose members are the users who have redeemed the link. The Graph + per-item permissions response represents the link itself but does *not* + return that site group as a separate ``siteGroup`` grant, so we translate. + + Args: + permission: A Graph permission with a ``link`` block. + sp_unique_id: The file's SharePoint UniqueId (lowercase GUID, no + braces). 
Pass ``None`` for site/drive-level permissions, where + sharing-link translation does not apply. + + Returns: + ``group:sp:sharinglinks...`` viewer string, or + ``None`` if the permission isn't a translatable link or required + fields are missing. + """ + if not sp_unique_id: + return None link = permission.get("link") if not link: - return False - return link.get("scope", "") == "anonymous" + return None + scope = link.get("scope", "") + if scope == "anonymous": + return None # handled by is_public + scope_role = _SCOPE_ROLE_MAP.get((scope, link.get("type", ""))) + if not scope_role: + return None # unknown scope/type combination — be conservative + link_id = permission.get("id", "") + if not link_id: + return None + title = f"SharingLinks.{sp_unique_id}.{scope_role}.{link_id}" + return f"group:sp:{title.lower()}" def _extract_identity_principals(perm: Dict[str, Any], viewers: List[str]) -> None: @@ -106,11 +158,17 @@ def _extract_identity_principals(perm: Dict[str, Any], viewers: List[str]) -> No async def extract_access_control( permissions: List[Dict[str, Any]], + sp_unique_id: Optional[str] = None, ) -> AccessControl: """Build AccessControl from Graph API permissions. Args: permissions: List of permission objects from Graph API. + sp_unique_id: The SharePoint UniqueId of the item the permissions + belong to (lowercase GUID, no braces). Required to translate + non-anonymous sharing-link permissions into their corresponding + ``SharingLinks.*`` SP site group viewer. Pass ``None`` for + site/drive-level permission lists. Returns: AccessControl with viewers and is_public flag. @@ -122,10 +180,17 @@ async def extract_access_control( if not has_read_permission(perm): continue - if is_org_wide_link(perm) or is_anonymous_link(perm): + if is_anonymous_link(perm): is_public = True continue + # Non-anonymous sharing links: translate to the per-link SP site group. 
+ link_viewer = link_permission_to_sp_group_viewer(perm, sp_unique_id) + if link_viewer: + if link_viewer not in viewers: + viewers.append(link_viewer) + continue + principal = extract_principal_from_permission(perm) if principal and principal not in viewers: viewers.append(principal) diff --git a/backend/airweave/platform/sources/sharepoint_online/builders.py b/backend/airweave/platform/sources/sharepoint_online/builders.py index d4d1cb59f..eb8412b6a 100644 --- a/backend/airweave/platform/sources/sharepoint_online/builders.py +++ b/backend/airweave/platform/sources/sharepoint_online/builders.py @@ -9,11 +9,10 @@ from typing import Any, Dict, List, Optional from airweave.domains.sync_pipeline.exceptions import EntityProcessingError -from airweave.platform.entities._base import Breadcrumb +from airweave.platform.entities._base import AccessControl, Breadcrumb from airweave.platform.entities.sharepoint_online import ( SharePointOnlineDriveEntity, SharePointOnlineFileEntity, - SharePointOnlineItemEntity, SharePointOnlinePageEntity, SharePointOnlineSiteEntity, ) @@ -32,6 +31,7 @@ def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: async def build_site_entity( site_data: Dict[str, Any], breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlineSiteEntity: """Build a site entity from Graph API site data.""" site_id = site_data.get("id") @@ -53,6 +53,7 @@ async def build_site_entity( created_at=_parse_datetime(site_data.get("createdDateTime")), last_modified_at=_parse_datetime(site_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) @@ -60,6 +61,7 @@ async def build_drive_entity( drive_data: Dict[str, Any], site_id: str, breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlineDriveEntity: """Build a drive entity from Graph API drive data.""" drive_id = drive_data.get("id") @@ -84,6 +86,7 @@ async def build_drive_entity( 
created_at=_parse_datetime(drive_data.get("createdDateTime")), last_modified_at=_parse_datetime(drive_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) @@ -93,8 +96,22 @@ async def build_file_entity( site_id: str, breadcrumbs: List[Breadcrumb], permissions: Optional[List[Dict[str, Any]]] = None, + sp_unique_id: Optional[str] = None, ) -> SharePointOnlineFileEntity: - """Build a file entity from Graph API drive item data.""" + """Build a file entity from Graph API drive item data. + + Args: + item_data: Graph drive item dict. + drive_id: Drive ID containing the item. + site_id: Site ID the drive belongs to. + breadcrumbs: Hierarchy breadcrumbs. + permissions: Optional permissions list from + ``/drives/{id}/items/{id}/permissions``. + sp_unique_id: Optional SharePoint ``listItemUniqueId`` for the item. + Required to translate sharing-link permissions; the caller should + fetch it via :meth:`GraphClient.get_item_sp_unique_id` when any + of the item's permissions has a ``link`` block. 
+ """ item_id = item_data.get("id") if not item_id: raise EntityProcessingError("Missing id for file item") @@ -129,7 +146,7 @@ async def build_file_entity( download_url = item_data.get("@microsoft.graph.downloadUrl", "") spo_entity_id = f"spo:file:{drive_id}:{item_id}" - access = await extract_access_control(permissions or []) if permissions else None + access = await extract_access_control(permissions or [], sp_unique_id) if permissions else None return SharePointOnlineFileEntity( url=download_url or item_data.get("webUrl", ""), @@ -154,46 +171,11 @@ async def build_file_entity( ) -async def build_item_entity( - item_data: Dict[str, Any], - site_id: str, - list_id: str, - breadcrumbs: List[Breadcrumb], -) -> SharePointOnlineItemEntity: - """Build a list item entity from Graph API list item data.""" - item_id = item_data.get("id") - if not item_id: - raise EntityProcessingError("Missing id for list item") - - fields = item_data.get("fields", {}) or {} - title = fields.get("Title") or fields.get("title") or item_data.get("id", "Untitled") - - content_type = None - ct_obj = item_data.get("contentType") - if ct_obj: - content_type = ct_obj.get("name") - - spo_entity_id = f"spo:item:{site_id}:{list_id}:{item_id}" - - return SharePointOnlineItemEntity( - spo_entity_id=spo_entity_id, - item_id=item_id, - list_id=list_id, - site_id=site_id, - title=title, - web_url=item_data.get("webUrl", ""), - content_type=content_type, - fields=fields, - created_at=_parse_datetime(item_data.get("createdDateTime")), - updated_at=_parse_datetime(item_data.get("lastModifiedDateTime")), - breadcrumbs=breadcrumbs, - ) - - async def build_page_entity( page_data: Dict[str, Any], site_id: str, breadcrumbs: List[Breadcrumb], + access: Optional[AccessControl] = None, ) -> SharePointOnlinePageEntity: """Build a page entity from Graph API site page data.""" page_id = page_data.get("id") @@ -214,4 +196,5 @@ async def build_page_entity( created_at=_parse_datetime(page_data.get("createdDateTime")), 
updated_at=_parse_datetime(page_data.get("lastModifiedDateTime")), breadcrumbs=breadcrumbs, + access=access, ) diff --git a/backend/airweave/platform/sources/sharepoint_online/client.py b/backend/airweave/platform/sources/sharepoint_online/client.py index a93e20dde..fcfd945ff 100644 --- a/backend/airweave/platform/sources/sharepoint_online/client.py +++ b/backend/airweave/platform/sources/sharepoint_online/client.py @@ -137,6 +137,12 @@ async def search_sites(self, query: str = "*") -> AsyncGenerator[Dict[str, Any], async for site in self.get_paginated(url, params): yield site + async def get_all_sites(self) -> AsyncGenerator[Dict[str, Any], None]: + """Enumerate all sites in the tenant (requires application permissions).""" + url = f"{GRAPH_BASE_URL}/sites/getAllSites" + async for site in self.get_paginated(url): + yield site + async def get_subsites(self, site_id: str) -> AsyncGenerator[Dict[str, Any], None]: """Get subsites of a SharePoint site.""" url = f"{GRAPH_BASE_URL}/sites/{site_id}/sites" @@ -233,11 +239,18 @@ async def get_drive_delta( self, drive_id: str, delta_token: str = "", + prefer_headers: Optional[List[str]] = None, ) -> Tuple[List[Dict[str, Any]], str]: """Get changes since the last delta token. Returns (changed_items, new_delta_token). If delta_token is empty, returns all items (initial sync). + + Args: + drive_id: The drive to query. + delta_token: Continuation token from a previous delta query. + prefer_headers: Optional Prefer header values for app-only delta + (e.g., ["deltashowsharingchanges", "deltashowremovedasdeleted"]). 
""" if delta_token: url = delta_token # Delta tokens are full URLs @@ -249,7 +262,15 @@ async def get_drive_delta( delta_link = "" while current_url: - data = await self.get(current_url) + if prefer_headers: + headers = await self._headers() + headers["Prefer"] = ", ".join(prefer_headers) + self.logger.debug(f"GET {current_url} (Prefer: {headers['Prefer']})") + response = await self._http_client.get(current_url, headers=headers, timeout=30.0) + response.raise_for_status() + data = response.json() + else: + data = await self.get(current_url) items = data.get("value", []) all_items.extend(items) @@ -282,6 +303,67 @@ async def get_item_permissions( return [] raise + async def list_internal_tenant_users(self) -> AsyncGenerator[Dict[str, str], None]: + """Yield internal tenant members (``userType eq 'Member'``). + + Used to expand the SharePoint "Everyone except external users" claim + into per-user memberships of the synthetic claim group. The Graph + filter excludes guests (``userType eq 'Guest'``), preserving the + claim's "except external" semantics. + + Yields: + Dicts with at least ``email`` (mail or userPrincipalName, + lowercased) and ``display_name``. Users without any addressable + identifier are skipped. + """ + url = ( + f"{GRAPH_BASE_URL}/users" + "?$filter=userType eq 'Member'" + "&$select=id,mail,userPrincipalName,displayName" + ) + async for u in self.get_paginated(url): + email = (u.get("mail") or u.get("userPrincipalName") or "").strip().lower() + if not email: + continue + yield {"email": email, "display_name": u.get("displayName") or email} + + async def get_item_sp_unique_id( + self, + drive_id: str, + item_id: str, + ) -> Optional[str]: + """Fetch the SharePoint ``listItemUniqueId`` (lowercase GUID) for a drive item. + + Used to translate sharing-link permissions into the underlying + ``SharingLinks...`` SP site group viewer. 
+ Only worth calling when the item has at least one ``link`` permission; + for items with only direct grants there's nothing to translate. + """ + url = f"{GRAPH_BASE_URL}/drives/{drive_id}/items/{item_id}?$select=sharepointIds" + try: + data = await self.get(url) + except httpx.HTTPStatusError as e: + if e.response.status_code == 404: + return None + raise + sp_ids = data.get("sharepointIds") or {} + luid = sp_ids.get("listItemUniqueId") + return luid.lower() if luid else None + + async def get_drive_root_permissions( + self, + drive_id: str, + ) -> List[Dict[str, Any]]: + """Get permissions for the root of a drive (site-level permissions).""" + url = f"{GRAPH_BASE_URL}/drives/{drive_id}/root/permissions" + try: + data = await self.get(url) + return data.get("value", []) + except httpx.HTTPStatusError as e: + if e.response.status_code in (404, 403): + return [] + raise + # -- Lists -- async def get_lists(self, site_id: str) -> AsyncGenerator[Dict[str, Any], None]: diff --git a/backend/airweave/platform/sources/sharepoint_online/source.py b/backend/airweave/platform/sources/sharepoint_online/source.py index cfa141458..5ba50ca54 100644 --- a/backend/airweave/platform/sources/sharepoint_online/source.py +++ b/backend/airweave/platform/sources/sharepoint_online/source.py @@ -18,13 +18,18 @@ Incremental sync: - Uses Graph delta queries (/drives/{id}/root/delta) - Per-drive delta tokens stored in cursor + +Two source variants: +- SharePointOnlineSource: OAuth (delegated user auth) +- SharePointOnlineAppSource: Client credentials (app-only auth) """ from __future__ import annotations import asyncio +import re from dataclasses import dataclass -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Set, Tuple from urllib.parse import urlparse import httpx @@ -34,11 +39,14 @@ from airweave.domains.access_control.schemas import MembershipTuple from airweave.domains.browse_tree.types 
import BrowseNode, NodeSelectionData from airweave.domains.sources.exceptions import SourceAuthError +from airweave.domains.sources.token_providers.credential import DirectCredentialProvider from airweave.domains.sources.token_providers.protocol import TokenProviderProtocol +from airweave.domains.sources.token_providers.static import StaticTokenProvider from airweave.domains.storage import FileSkippedException from airweave.domains.storage.file_service import FileService from airweave.domains.sync_pipeline.exceptions import EntityProcessingError from airweave.domains.syncs.cursors.cursor import SyncCursor +from airweave.platform.configs.auth import SharePointOnlineAppAuthConfig from airweave.platform.configs.config import SharePointOnlineConfig from airweave.platform.cursors.sharepoint_online import SharePointOnlineCursor from airweave.platform.decorators import source @@ -53,6 +61,7 @@ retry_if_rate_limit_or_timeout, wait_rate_limit_with_backoff, ) +from airweave.platform.sources.sharepoint_online.acl import extract_access_control from airweave.platform.sources.sharepoint_online.builders import ( build_drive_entity, build_file_entity, @@ -66,6 +75,14 @@ MAX_CONCURRENT_FILE_DOWNLOADS = 10 ITEM_BATCH_SIZE = 50 +# Synthetic principal representing the SharePoint "Everyone except external users" +# claim. SP exposes this claim as a member of site groups but our membership +# table only handles real users / Entra groups / SP groups. We translate the +# claim into a synthetic group, populate it with the tenant's internal members +# at sync time, and let the broker's recursive expansion do the rest. 
+EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL = "claim:everyone_except_external" +EVERYONE_EXCEPT_EXTERNAL_DISPLAY_NAME = "Everyone except external users (synthetic)" + @dataclass class PendingFileDownload: @@ -76,52 +93,88 @@ class PendingFileDownload: item_id: str -@source( - name="SharePoint Online", - short_name="sharepoint_online", - auth_methods=[ - AuthenticationMethod.OAUTH_BROWSER, - AuthenticationMethod.OAUTH_TOKEN, - AuthenticationMethod.AUTH_PROVIDER, - ], - oauth_type=OAuthType.WITH_ROTATING_REFRESH, - auth_config_class=None, - config_class=SharePointOnlineConfig, - supports_continuous=True, - cursor_class=SharePointOnlineCursor, - supports_access_control=True, - supports_browse_tree=True, - feature_flag="sharepoint_2019_v2", - labels=["Collaboration", "File Storage"], -) -class SharePointOnlineSource(BaseSource): - """SharePoint Online source using Microsoft Graph API. +# ============================================================================= +# Base class — shared sync, browse tree, download, and ACL logic +# ============================================================================= + - Syncs sites, drives, files, lists, and pages with full ACL support. - Uses Entra ID for group membership expansion. +class SharePointOnlineBase(BaseSource): + """Shared implementation for SharePoint Online sources. 
+ + Subclasses must implement the auth-specific hooks: + - create() — class constructor + - _get_access_token() — return a valid Microsoft Graph token + - _handle_401() — refresh/re-exchange on 401, return new token + - _make_sp_token_provider_for_site(site_url) — per-site SP REST token provider + - _get_download_auth(url) — auth suitable for file download + - _discover_sites(graph_client) — site discovery strategy """ - @classmethod - async def create( - cls, - *, - auth: TokenProviderProtocol, - logger: ContextualLogger, - http_client: AirweaveHttpClient, - config: SharePointOnlineConfig, - ) -> SharePointOnlineSource: - """Create and configure a SharePoint Online source instance.""" - instance = cls(auth=auth, logger=logger, http_client=http_client) - instance._site_url = config.site_url.rstrip("/") if config.site_url else "" - instance._include_personal_sites = config.include_personal_sites - instance._include_pages = config.include_pages - instance._item_level_entra_groups: set = set() - instance._item_level_sp_groups: set = set() - return instance + # Instance attributes set by _init_common() + _site_url: str + _include_personal_sites: bool + _include_pages: bool + _item_level_entra_groups: Set[str] + # Site-scoped SP group tracking: {site_url: {sp_group_name, ...}} + # Keyed by normalized site URL so multi-site syncs can expand SP groups per site. + _item_level_sp_groups: Dict[str, Set[str]] + # Set to True during membership extraction when an SP group contains the + # "Everyone except external users" claim, so we know to enumerate internal + # tenant users once at the end. 
+ _needs_internal_user_enum: bool + + def _init_common(self, config: SharePointOnlineConfig) -> None: + """Initialize fields shared by both OAuth and client-credentials sources.""" + self._site_url = config.site_url.rstrip("/") if config.site_url else "" + self._include_personal_sites = config.include_personal_sites + self._include_pages = config.include_pages + self._item_level_entra_groups = set() + self._item_level_sp_groups = {} + self._needs_internal_user_enum = False + + # -- Auth hooks (subclasses override) -- + + async def _get_access_token(self) -> str: + """Get a valid Microsoft Graph access token.""" + raise NotImplementedError + + async def _handle_401(self) -> str: + """Handle a 401 by refreshing/re-exchanging. Returns new token.""" + raise NotImplementedError + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create an SP REST token provider scoped to a specific site URL. + + Subclasses must override. Returns None if a token cannot be obtained + for the given site (e.g., malformed URL). 
+ """ + raise NotImplementedError + + async def _get_download_auth(self, url: str) -> Any: + """Return an auth object suitable for FileService.download_from_url.""" + return self.auth + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover SharePoint sites to sync.""" + raise NotImplementedError + + @property + def _delta_prefer_headers(self) -> List[str]: + """Prefer headers for delta queries (permission change tracking).""" + return [] + + # -- Shared client factories -- def _create_graph_client(self) -> GraphClient: return GraphClient( - access_token_provider=self.auth.get_token, + access_token_provider=self._get_access_token, + http_client=self.http_client, + logger=self.logger, + ) + + def _create_group_expander(self) -> EntraGroupExpander: + return EntraGroupExpander( + access_token_provider=self._get_access_token, http_client=self.http_client, logger=self.logger, ) @@ -134,13 +187,13 @@ def _create_graph_client(self) -> GraphClient: ) async def _get(self, url: str, params: Optional[Dict] = None) -> Dict[str, Any]: """Make an authenticated GET request to Microsoft Graph API.""" - token = await self.auth.get_token() + token = await self._get_access_token() headers = {"Authorization": f"Bearer {token}", "Accept": "application/json"} response = await self.http_client.get(url, headers=headers, params=params) - if response.status_code == 401 and self.auth.supports_refresh: + if response.status_code == 401: self.logger.warning("Received 401 from Microsoft Graph API — refreshing token") - new_token = await self.auth.force_refresh() + new_token = await self._handle_401() headers = {"Authorization": f"Bearer {new_token}", "Accept": "application/json"} response = await self.http_client.get(url, headers=headers, params=params) @@ -151,53 +204,143 @@ async def _get(self, url: str, params: Optional[Dict] = None) -> Dict[str, Any]: ) return response.json() - def _create_group_expander(self) -> EntraGroupExpander: - return 
EntraGroupExpander( - access_token_provider=self.auth.get_token, - http_client=self.http_client, - logger=self.logger, - ) - - def _derive_sp_resource_scope(self) -> Optional[str]: - """Derive the SharePoint resource scope from the site URL. - - E.g. https://neenacorp.sharepoint.com/sites/JAman - -> https://neenacorp.sharepoint.com/.default - """ + def _derive_sp_hostname(self) -> Optional[str]: + """Derive the SharePoint hostname from the site URL.""" if not self._site_url: return None parsed = urlparse(self._site_url) - if not parsed.netloc: - return None - return f"https://{parsed.netloc}/.default" - - def _make_sp_token_provider(self) -> Optional[Callable]: - """Create an async callable that returns a SharePoint-scoped token. - - Returns None if the site URL is not set or no token manager is available. + return parsed.netloc or None + + @staticmethod + def _normalize_site_url(site_url: str) -> str: + """Normalize a site URL for use as a dict key (strip trailing slash).""" + return (site_url or "").rstrip("/") + + def _track_entity_groups(self, entity: BaseEntity, site_url: str = "") -> None: + """Track Entra ID and SP site groups found in entity permissions. + + Args: + entity: The entity whose access viewers to inspect. + site_url: The site URL this entity belongs to. SP groups are keyed + by site URL so multi-site syncs can expand SP groups per-site. + May be empty for paths that lack site context (incremental / + targeted single-file); those SP groups won't expand. 
""" - sp_scope = self._derive_sp_resource_scope() - if not sp_scope: - return None - - async def _provider() -> str: - token = await self.get_token_for_resource(sp_scope) - if not token: - raise RuntimeError(f"Could not obtain SharePoint token for scope {sp_scope}") - return token - - return _provider - - def _track_entity_groups(self, entity: BaseEntity) -> None: - """Track Entra ID and SP site groups found in entity permissions.""" if not hasattr(entity, "access") or entity.access is None: return + norm_site = self._normalize_site_url(site_url) for viewer in entity.access.viewers or []: if viewer.startswith("group:entra:"): group_id = viewer[len("group:") :] self._item_level_entra_groups.add(group_id) elif viewer.startswith("group:sp:"): - self._item_level_sp_groups.add(viewer[len("group:") :]) + sp_name = viewer[len("group:") :] + self._item_level_sp_groups.setdefault(norm_site, set()).add(sp_name) + + # -- SP site group membership parsing -- + + # Match regular user logins: "i:0#.f|membership|" + _MEMBERSHIP_LOGIN_RE = re.compile(r"^i:0#\.f\|membership\|(?P[^|]+@[^|]+)$") + # Match Entra federated group logins: "c:0o.c|federateddirectoryclaimprovider|[_o]" + _ENTRA_GROUP_LOGIN_RE = re.compile( + r"^c:0o\.c\|federateddirectoryclaimprovider\|" + r"(?P[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-" + r"[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(_o)?$" + ) + # Match the "Everyone except external users" claim: + # "c:0-.f|rolemanager|spo-grid-all-users/" + # PrincipalType=4 (SecurityGroup), but the claim provider is `rolemanager` + # rather than `federateddirectoryclaimprovider`. Represents all tenant + # users with userType=Member; excludes B2B guests by definition. + _EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE = re.compile( + r"^c:0-\.f\|rolemanager\|spo-grid-all-users/[0-9a-fA-F-]+$" + ) + + @classmethod + def _email_from_membership_login(cls, login: str) -> Optional[str]: + """Extract email from SP user LoginName if it follows the membership pattern. 
+ + Only matches "i:0#.f|membership|<email>". Returns None for role principals + (e.g., "c:0-.f|rolemanager|spo-grid-all-users/...") and other shapes so + we don't pollute the membership table with fake email-like strings. + """ + if not login: + return None + m = cls._MEMBERSHIP_LOGIN_RE.match(login) + if m: + return m.group("email").strip().lower() or None + return None + + @classmethod + def _parse_sp_group_member(cls, user: Dict[str, Any]) -> Optional[Tuple[str, str]]: + """Parse one entry from /_api/web/sitegroups({id})/users into (member_id, member_type). + + Returns None for entries that should not become memberships: + - Catch-all "All" / "Everyone" principals (PrincipalType=15) + - DistList, SPGroup, RoleManager (other than the recognized claim below) + - Unparseable entries (no email for users, no GUID for groups) + + Recognized PrincipalType=4 shapes: + - Entra federated group: ``c:0o.c|federateddirectoryclaimprovider|<guid>[_o]`` + → returns ``("entra:<guid>", "group")``. + - "Everyone except external users" claim: + ``c:0-.f|rolemanager|spo-grid-all-users/<tenant-guid>`` → returns the + synthetic ``(EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, "group")`` sentinel. + The caller (``_expand_sp_site_groups``) then enumerates internal + tenant users once per sync to populate the synthetic group. + - Any other PT=4 LoginName: returns None. The caller logs the raw + shape at info-level so unknown claim shapes show up in operator + logs and can be wired up explicitly later. 
+ + PrincipalType reference: + 1 = User + 2 = DistList + 4 = SecurityGroup (Entra group OR rolemanager claim) + 8 = SPGroup + 15 = All + 16 = RoleManager + """ + ptype = user.get("PrincipalType") + login = user.get("LoginName", "") or "" + + if ptype == 1: + email = user.get("Email") or "" + email = email.strip().lower() + if not email: + email = cls._email_from_membership_login(login) or "" + if not email: + return None + # Bare email (no "user:" prefix) matches the broker storage + # convention used by EntraGroupExpander and SP 2019 V2. + return (email, "user") + + if ptype == 4: + m = cls._ENTRA_GROUP_LOGIN_RE.match(login) + if m: + return (f"entra:{m.group('guid').lower()}", "group") + if cls._EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE.match(login): + return (EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, "group") + return None + + # PrincipalType 2 (DistList), 8 (SPGroup), 15 (All), 16 (RoleManager), + # and unknown types are intentionally skipped. + return None + + @classmethod + def _is_unrecognized_pt4_login(cls, user: Dict[str, Any]) -> bool: + """Return True for a PT=4 entry whose LoginName matches none of our patterns. + + Used at the call site to emit a one-line diagnostic so that unknown + claim shapes (rare custom rolemanager roles, legacy Windows claims, + etc.) surface in operator logs without breaking sync. + """ + if user.get("PrincipalType") != 4: + return False + login = user.get("LoginName", "") or "" + return not ( + cls._ENTRA_GROUP_LOGIN_RE.match(login) + or cls._EVERYONE_EXCEPT_EXTERNAL_LOGIN_RE.match(login) + ) # -- Browse Tree -- @@ -236,14 +379,7 @@ async def get_browse_children( self, parent_node_id: Optional[str] = None, ) -> List[BrowseNode]: - """Lazy-load tree nodes from Microsoft Graph API. 
- - Tree structure: - - Root (parent_node_id=None): returns discovered sites - - Site node (site:{site_id}): returns drives for the site - - Drive node (drive:{site_id}|{drive_id}): returns root children of the drive - - Folder node (folder:{drive_id}|{folder_id}): returns children of the folder - """ + """Lazy-load tree nodes from Microsoft Graph API.""" graph_client = self._create_graph_client() nodes: List[BrowseNode] = [] @@ -381,10 +517,13 @@ async def _download_and_save_file( entity.url = ( f"https://graph.microsoft.com/v1.0/drives/{drive_id}/items/{item_id}/content" ) + + auth = await self._get_download_auth(entity.url) + await files.download_from_url( entity=entity, client=self.http_client, - auth=self.auth, + auth=auth, logger=self.logger, ) return entity @@ -416,6 +555,8 @@ async def download_one(item: PendingFileDownload): self.logger.debug(f"File download skipped for {item.drive_id}/{item.item_id}") except EntityProcessingError as e: self.logger.warning(f"Skipping file download: {e}") + except Exception as e: + self.logger.warning(f"Unexpected error downloading {item.entity.name}: {e}") tasks = [asyncio.create_task(download_one(p)) for p in pending] await asyncio.gather(*tasks, return_exceptions=True) @@ -439,7 +580,7 @@ def _should_do_full_sync(self, cursor: SyncCursor | None) -> tuple: # -- Entity Generation -- - async def generate_entities( + async def generate_entities( # noqa: C901 self, *, cursor: SyncCursor | None = None, @@ -450,8 +591,19 @@ async def generate_entities( cursor_data = cursor.data if cursor else {} for g in cursor_data.get("tracked_entra_groups", []): self._item_level_entra_groups.add(g) - for g in cursor_data.get("tracked_sp_groups", []): - self._item_level_sp_groups.add(g) + + # tracked_sp_groups format changed from List[str] (flat names) to + # Dict[site_url, List[str]] (site-scoped). Migrate defensively. 
+ tracked_sp = cursor_data.get("tracked_sp_groups") + if isinstance(tracked_sp, dict): + for site_url, names in tracked_sp.items(): + if isinstance(names, list): + self._item_level_sp_groups[site_url] = set(names) + elif isinstance(tracked_sp, list): + self.logger.info( + "Legacy tracked_sp_groups list format detected; discarding — " + "will re-collect on next full sync" + ) if node_selections: self.logger.info(f"Sync strategy: TARGETED ({len(node_selections)} node selections)") @@ -469,39 +621,6 @@ async def generate_entities( async for entity in self._incremental_sync(cursor, files): yield entity - async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: - """Discover sites to sync based on config. - - Supports: - - Single URL: "https://tenant.sharepoint.com/sites/MySite" - - Comma-separated: "https://tenant.sharepoint.com/sites/A, .../sites/B" - - Empty string: discover all accessible sites - """ - sites = [] - - if self._site_url: - urls = [u.strip() for u in self._site_url.split(",") if u.strip()] - for url in urls: - parsed = urlparse(url) - hostname = parsed.netloc - site_path = parsed.path.lstrip("/") - try: - site = await graph_client.get_site_by_url(hostname, site_path) - sites.append(site) - except SourceAuthError: - raise - except Exception as e: - self.logger.warning(f"Could not resolve site URL {url}: {e}") - raise - else: - async for site in graph_client.search_sites("*"): - if not self._include_personal_sites and site.get("isPersonalSite", False): - continue - sites.append(site) - - self.logger.info(f"Discovered {len(sites)} sites to sync") - return sites - async def _resolve_unresolved_viewers( self, entity: BaseEntity, graph_client: GraphClient ) -> None: @@ -527,43 +646,10 @@ async def _resolve_unresolved_viewers( new_viewers.append(v) entity.access.viewers = new_viewers - async def _fetch_sp_group_viewers(self) -> List[str]: - """Fetch all SP site groups and return their viewer strings. 
- - Uses the shared http_client with SP-scoped token headers. - Returns empty list if SP token is unavailable. - """ - sp_token_provider = self._make_sp_token_provider() - if not sp_token_provider or not self._site_url: - return [] - try: - token = await sp_token_provider() - headers = { - "Authorization": f"Bearer {token}", - "Accept": "application/json;odata=verbose", - } - resp = await self.http_client.get( - f"{self._site_url}/_api/web/sitegroups", - headers=headers, - timeout=30.0, - ) - resp.raise_for_status() - groups = resp.json().get("d", {}).get("results", []) - - viewers = [] - for g in groups: - title = g.get("Title", "") - if title: - tag = f"group:sp:{title.lower().replace(' ', '_')}" - viewers.append(tag) - self._item_level_sp_groups.add(tag[len("group:") :]) - self.logger.info(f"Fetched {len(viewers)} SP site groups as viewers") - return viewers - except SourceAuthError: - raise - except Exception as e: - self.logger.warning(f"SP group fetch failed: {e}") - return [] + @staticmethod + def _has_link_permission(permissions: List[Dict[str, Any]]) -> bool: + """Return True if any permission carries a sharing-link block.""" + return any(p.get("link") for p in (permissions or [])) async def _full_sync( # noqa: C901 self, @@ -577,9 +663,27 @@ async def _full_sync( # noqa: C901 for site_data in sites: site_id = site_data.get("id", "") + site_url = self._normalize_site_url(site_data.get("webUrl", "")) + + # Collect all drives for this site (single API call) + all_drives = [] + async for drive_data in graph_client.get_drives(site_id): + all_drives.append(drive_data) + + # Fetch site-level permissions from the first drive's root. 
+ site_access = None + if all_drives: + try: + site_permissions = await graph_client.get_drive_root_permissions( + all_drives[0]["id"] + ) + site_access = await extract_access_control(site_permissions) + except Exception as e: + self.logger.warning(f"Could not fetch site-level permissions: {e}") try: - site_entity = await build_site_entity(site_data, []) + site_entity = await build_site_entity(site_data, [], access=site_access) + self._track_entity_groups(site_entity, site_url) yield site_entity entity_count += 1 @@ -593,12 +697,24 @@ async def _full_sync( # noqa: C901 self.logger.warning(f"Skipping site {site_id}: {e}") continue - sp_group_viewers = await self._fetch_sp_group_viewers() - - async for drive_data in graph_client.get_drives(site_id): + for drive_data in all_drives: drive_id = drive_data.get("id", "") try: - drive_entity = await build_drive_entity(drive_data, site_id, site_breadcrumbs) + # Each drive gets its own root permissions + drive_access = site_access + if drive_id != all_drives[0]["id"]: + try: + drive_permissions = await graph_client.get_drive_root_permissions( + drive_id + ) + drive_access = await extract_access_control(drive_permissions) + except Exception: + pass # Fall back to site_access + + drive_entity = await build_drive_entity( + drive_data, site_id, site_breadcrumbs, access=drive_access + ) + self._track_entity_groups(drive_entity, site_url) yield drive_entity entity_count += 1 @@ -622,21 +738,26 @@ async def _full_sync( # noqa: C901 item_data["id"], ) + # Sharing-link permissions need the file's SP UniqueId + # to translate into the SharingLinks.* SP site group. + # Skip the extra fetch when the file has no sharing links. 
+ sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_data["id"] + ) + file_entity = await build_file_entity( item_data, drive_id, site_id, drive_breadcrumbs, permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) - if sp_group_viewers and file_entity.access: - existing = set(file_entity.access.viewers or []) - for spv in sp_group_viewers: - if spv not in existing: - file_entity.access.viewers.append(spv) - self._track_entity_groups(file_entity) + self._track_entity_groups(file_entity, site_url) if files: pending_files.append( @@ -661,6 +782,8 @@ async def _full_sync( # noqa: C901 except EntityProcessingError as e: self.logger.warning(f"Skipping file: {e}") + except Exception as e: + self.logger.warning(f"Unexpected error processing file: {e}") if pending_files and files: downloaded = await self._download_files_parallel(pending_files, files) @@ -670,7 +793,9 @@ async def _full_sync( # noqa: C901 if cursor: try: - _, delta_token = await graph_client.get_drive_delta(drive_id) + _, delta_token = await graph_client.get_drive_delta( + drive_id, prefer_headers=self._delta_prefer_headers + ) if delta_token: cursor_schema = SharePointOnlineCursor(**cursor.data) cursor_schema.update_entity_cursor( @@ -696,8 +821,9 @@ async def _full_sync( # noqa: C901 async for page_data in graph_client.get_pages(site_id): try: page_entity = await build_page_entity( - page_data, site_id, site_breadcrumbs + page_data, site_id, site_breadcrumbs, access=site_access ) + self._track_entity_groups(page_entity, site_url) yield page_entity entity_count += 1 except EntityProcessingError as e: @@ -718,7 +844,9 @@ async def _full_sync( # noqa: C901 full_sync_required=False, total_entities_synced=entity_count, tracked_entra_groups=list(self._item_level_entra_groups), - tracked_sp_groups=list(self._item_level_sp_groups), + tracked_sp_groups={ + site: 
sorted(names) for site, names in self._item_level_sp_groups.items() + }, ) self.logger.info(f"Full sync complete: {entity_count} entities") @@ -743,7 +871,9 @@ async def _incremental_sync( # noqa: C901 for drive_id, token in delta_tokens.items(): try: - changed_items, new_token = await graph_client.get_drive_delta(drive_id, token) + changed_items, new_token = await graph_client.get_drive_delta( + drive_id, token, prefer_headers=self._delta_prefer_headers + ) except SourceAuthError: raise except Exception as e: @@ -776,12 +906,18 @@ async def _incremental_sync( # noqa: C901 if item_data.get("file"): try: permissions = await graph_client.get_item_permissions(drive_id, item_id) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_id + ) file_entity = await build_file_entity( item_data, drive_id, "", [], permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) self._track_entity_groups(file_entity) @@ -848,7 +984,20 @@ async def _targeted_sync( # noqa: C901 try: site_data = await graph_client.get_site(site_id) - site_entity = await build_site_entity(site_data, []) + targeted_site_url = self._normalize_site_url(site_data.get("webUrl", "")) + + # Fetch site-level permissions from first drive root + targeted_site_access = None + async for peek_drive in graph_client.get_drives(site_id): + try: + perms = await graph_client.get_drive_root_permissions(peek_drive["id"]) + targeted_site_access = await extract_access_control(perms) + except Exception: + pass + break + + site_entity = await build_site_entity(site_data, [], access=targeted_site_access) + self._track_entity_groups(site_entity, targeted_site_url) yield site_entity entity_count += 1 except SourceAuthError: @@ -906,8 +1055,18 @@ async def _targeted_sync( # noqa: C901 item_data = await graph_client.get(url) if item_data.get("file"): permissions = await 
graph_client.get_item_permissions(drive_id, item_id) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_id + ) file_entity = await build_file_entity( - item_data, drive_id, "", [], permissions + item_data, + drive_id, + "", + [], + permissions, + sp_unique_id=sp_unique_id, ) await self._resolve_unresolved_viewers(file_entity, graph_client) self._track_entity_groups(file_entity) @@ -935,7 +1094,19 @@ async def _sync_drive( """Sync all files in a single drive (used by both full and targeted sync).""" try: drive_data = await graph_client.get_drive(drive_id) - drive_entity = await build_drive_entity(drive_data, site_id, site_breadcrumbs) + + # Fetch drive root permissions for the drive entity + drive_access = None + try: + drive_permissions = await graph_client.get_drive_root_permissions(drive_id) + drive_access = await extract_access_control(drive_permissions) + except Exception: + pass + + drive_entity = await build_drive_entity( + drive_data, site_id, site_breadcrumbs, access=drive_access + ) + self._track_entity_groups(drive_entity) yield drive_entity drive_breadcrumbs = site_breadcrumbs + [ @@ -975,7 +1146,7 @@ async def _sync_folder_recursive( ): yield entity - async def _process_file_items( + async def _process_file_items( # noqa: C901 self, graph_client: GraphClient, item_stream: AsyncGenerator[Dict[str, Any], None], @@ -994,8 +1165,18 @@ async def _process_file_items( continue try: permissions = await graph_client.get_item_permissions(drive_id, item_data["id"]) + sp_unique_id = None + if self._has_link_permission(permissions): + sp_unique_id = await graph_client.get_item_sp_unique_id( + drive_id, item_data["id"] + ) file_entity = await build_file_entity( - item_data, drive_id, site_id, breadcrumbs, permissions + item_data, + drive_id, + site_id, + breadcrumbs, + permissions, + sp_unique_id=sp_unique_id, ) if resolve_viewers: await 
self._resolve_unresolved_viewers(file_entity, graph_client) @@ -1043,55 +1224,128 @@ async def _expand_entra_groups( async for membership in group_expander.expand_group(group_id): yield membership - async def _expand_sp_site_groups(self) -> AsyncGenerator[MembershipTuple, None]: - """Expand tracked SP site groups into user memberships. - - Uses the shared http_client with SP-scoped token headers. + async def _expand_sp_site_groups( # noqa: C901 + self, + ) -> AsyncGenerator[MembershipTuple, None]: + """Expand tracked SP site groups into user/group memberships. + + Iterates per-site: for each site URL we've tracked SP group names against, + fetches that site's SP groups via the SharePoint REST API and resolves + their members. + + Member types emitted: + - ``user`` for real users (PrincipalType=1). Role principals like + "Everyone except external users" are skipped. + - ``group`` for Entra security groups nested inside SP groups + (PrincipalType=4 with federateddirectoryclaimprovider). The broker's + recursive group expansion resolves these to individual users at + search time. 
""" - sp_group_names = list(self._item_level_sp_groups) - if not sp_group_names or not self._site_url: - return - sp_token_provider = self._make_sp_token_provider() - if not sp_token_provider: - self.logger.warning("No SP token provider for site group expansion") + if not self._item_level_sp_groups: return - self.logger.info(f"Expanding {len(sp_group_names)} SP site groups") + total_groups = sum(len(v) for v in self._item_level_sp_groups.values()) + self.logger.info( + f"Expanding {total_groups} SP site groups across " + f"{len(self._item_level_sp_groups)} site(s)" + ) + graph_client = self._create_graph_client() - sp_groups = await graph_client.get_site_groups( - self._site_url, - sp_token_provider=sp_token_provider, - ) - sp_name_to_id = { - f"sp:{g['Title'].replace(' ', '_').lower()}": g.get("Id") - for g in sp_groups - if g.get("Title") - } - - for sp_name in sp_group_names: - sp_id = sp_name_to_id.get(sp_name) - if not sp_id: - self.logger.debug(f"SP group '{sp_name}' not found in site") + for site_url, sp_group_names in self._item_level_sp_groups.items(): + if not site_url or not sp_group_names: continue - users = await graph_client.get_site_group_users( - self._site_url, - sp_id, - sp_token_provider=sp_token_provider, - ) - for user in users: - email = user.get("Email", "") - login = user.get("LoginName", "") - if not email and login and "|" in login: - email = login.split("|")[-1] - if email: + sp_token_provider = self._make_sp_token_provider_for_site(site_url) + if not sp_token_provider: + self.logger.warning( + f"No SP token provider for site {site_url}; skipping SP group expansion" + ) + continue + + try: + sp_groups = await graph_client.get_site_groups( + site_url, sp_token_provider=sp_token_provider + ) + except Exception as e: + self.logger.warning(f"Failed to fetch SP groups for {site_url}: {e}") + continue + + sp_name_to_id = { + f"sp:{g['Title'].replace(' ', '_').lower()}": g.get("Id") + for g in sp_groups + if g.get("Title") + } + + for sp_name in 
sp_group_names: + sp_id = sp_name_to_id.get(sp_name) + if not sp_id: + self.logger.debug(f"SP group '{sp_name}' not found in site {site_url}") + continue + + try: + users = await graph_client.get_site_group_users( + site_url, sp_id, sp_token_provider=sp_token_provider + ) + except Exception as e: + self.logger.warning( + f"Failed to fetch users for SP group {sp_name} in {site_url}: {e}" + ) + continue + + for user in users: + parsed = self._parse_sp_group_member(user) + if parsed is None: + if self._is_unrecognized_pt4_login(user): + self.logger.info( + "Unrecognized PrincipalType=4 SP group member; skipped. " + f"LoginName={user.get('LoginName', '')!r} " + f"Title={user.get('Title', '')!r}" + ) + continue + member_id, member_type = parsed yield MembershipTuple( - member_id=email.lower(), - member_type="user", + member_id=member_id, + member_type=member_type, group_id=sp_name, - group_name=user.get("Title", sp_name), + group_name=user.get("Title") or sp_name, ) + if member_id == EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL: + self._needs_internal_user_enum = True + + async def _expand_everyone_except_external( + self, + ) -> AsyncGenerator[MembershipTuple, None]: + """Populate the synthetic ``EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL`` group. + + Called once per sync, only when the SP group expansion observed at + least one occurrence of the claim. Enumerates internal tenant users + via Graph (``userType eq 'Member'`` filter excludes B2B guests) and + yields one user → claim membership per user. The broker's recursive + group expansion then chains user → claim → SP group at search time. 
+ """ + graph_client = self._create_graph_client() + count = 0 + try: + async for u in graph_client.list_internal_tenant_users(): + count += 1 + yield MembershipTuple( + member_id=u["email"], + member_type="user", + group_id=EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + group_name=EVERYONE_EXCEPT_EXTERNAL_DISPLAY_NAME, + ) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning( + f"Failed to enumerate internal tenant users for " + f"'{EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL}': {e}" + ) + self.logger.info( + f"Populated synthetic '{EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL}' group " + f"with {count} internal tenant users" + ) async def generate_access_control_memberships( self, @@ -1099,6 +1353,7 @@ async def generate_access_control_memberships( """Expand Entra ID groups and SP site groups into user memberships.""" self.logger.info("Starting access control membership extraction") membership_count = 0 + self._needs_internal_user_enum = False group_expander = self._create_group_expander() async for m in self._expand_entra_groups(group_expander): @@ -1114,5 +1369,352 @@ async def generate_access_control_memberships( except Exception as e: self.logger.warning(f"SP site group expansion failed: {e}") + # If any SP site group contained the "Everyone except external users" + # claim, populate the synthetic claim group with internal tenant users + # exactly once. Skipped entirely when no group used the claim. 
+ if self._needs_internal_user_enum: + async for m in self._expand_everyone_except_external(): + yield m + membership_count += 1 + group_expander.log_stats() self.logger.info(f"Access control extraction complete: {membership_count} memberships") + + +# ============================================================================= +# OAuth source — delegated user auth +# ============================================================================= + + +@source( + name="SharePoint Online", + short_name="sharepoint_online", + auth_methods=[ + AuthenticationMethod.OAUTH_BROWSER, + AuthenticationMethod.OAUTH_TOKEN, + AuthenticationMethod.AUTH_PROVIDER, + ], + oauth_type=OAuthType.WITH_ROTATING_REFRESH, + auth_config_class=None, + config_class=SharePointOnlineConfig, + supports_continuous=True, + cursor_class=SharePointOnlineCursor, + supports_access_control=True, + supports_browse_tree=True, + feature_flag="sharepoint_2019_v2", + labels=["Collaboration", "File Storage"], +) +class SharePointOnlineSource(SharePointOnlineBase): + """SharePoint Online source using delegated OAuth. + + Uses the signed-in user's permissions via OAuth browser flow. + Site discovery uses Graph search (delegated permissions). 
+ """ + + @classmethod + async def create( + cls, + *, + auth: TokenProviderProtocol, + logger: ContextualLogger, + http_client: AirweaveHttpClient, + config: SharePointOnlineConfig, + ) -> SharePointOnlineSource: + """Create and configure an OAuth SharePoint Online source.""" + instance = cls(auth=auth, logger=logger, http_client=http_client) + instance._init_common(config) + return instance + + async def _get_access_token(self) -> str: + return await self.auth.get_token() + + async def _handle_401(self) -> str: + if self.auth.supports_refresh: + return await self.auth.force_refresh() + return await self.auth.get_token() + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create SP token provider for a specific site URL via OAuth scope exchange.""" + if not site_url: + return None + parsed = urlparse(site_url) + hostname = parsed.netloc + if not hostname: + return None + sp_scope = f"https://{hostname}/.default" + + async def _provider() -> str: + token = await self.get_token_for_resource(sp_scope) + if not token: + raise RuntimeError(f"Could not obtain SharePoint token for scope {sp_scope}") + return token + + return _provider + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover sites via Graph search (delegated permissions). 
+ + Supports: + - Single URL: "https://tenant.sharepoint.com/sites/MySite" + - Comma-separated: "https://tenant.sharepoint.com/sites/A, .../sites/B" + - Empty string: search all accessible sites + """ + sites = [] + + if self._site_url: + urls = [u.strip() for u in self._site_url.split(",") if u.strip()] + for url in urls: + parsed = urlparse(url) + hostname = parsed.netloc + site_path = parsed.path.lstrip("/") + try: + site = await graph_client.get_site_by_url(hostname, site_path) + sites.append(site) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning(f"Could not resolve site URL {url}: {e}") + raise + else: + async for site in graph_client.search_sites("*"): + if not self._include_personal_sites and site.get("isPersonalSite", False): + continue + sites.append(site) + + self.logger.info(f"Discovered {len(sites)} sites to sync") + return sites + + +# ============================================================================= +# Client credentials source — app-only auth +# ============================================================================= + + +@source( + name="SharePoint Online (App)", + short_name="sharepoint_online_app", + auth_methods=[AuthenticationMethod.DIRECT], + auth_config_class=SharePointOnlineAppAuthConfig, + config_class=SharePointOnlineConfig, + supports_continuous=True, + cursor_class=SharePointOnlineCursor, + supports_access_control=True, + supports_browse_tree=True, + feature_flag="sharepoint_2019_v2", + labels=["Collaboration", "File Storage"], +) +class SharePointOnlineAppSource(SharePointOnlineBase): + """SharePoint Online source using client credentials (app-only auth). + + Uses client_id + client_secret for Graph API and certificate-based + authentication for SharePoint REST API. Requires Azure AD app registration + with application permissions and admin consent. 
+ """ + + _tenant_id: str + _client_id: str + _client_secret: str + _private_key: str + _certificate: str + _graph_token: Optional[str] + _graph_token_expires: float + _sp_tokens: Dict[str, tuple[str, float]] + + @classmethod + async def create( + cls, + *, + auth: DirectCredentialProvider, + logger: ContextualLogger, + http_client: AirweaveHttpClient, + config: SharePointOnlineConfig, + ) -> SharePointOnlineAppSource: + """Create and configure a client-credentials SharePoint Online source.""" + instance = cls(auth=auth, logger=logger, http_client=http_client) + instance._init_common(config) + + creds: SharePointOnlineAppAuthConfig = auth.credentials + instance._tenant_id = creds.tenant_id + instance._client_id = creds.client_id + instance._client_secret = creds.client_secret + instance._private_key = creds.private_key + instance._certificate = creds.certificate + + # Token cache + instance._graph_token = None + instance._graph_token_expires = 0.0 + instance._sp_tokens = {} # hostname -> (token, expires_at) + + # Exchange for initial Graph token + instance._graph_token = await instance._exchange_graph_token() + instance._graph_token_expires = asyncio.get_event_loop().time() + 3500 + + return instance + + # -- Token exchange (app-only mode) -- + + async def _exchange_graph_token(self) -> str: + """Exchange client credentials for a Microsoft Graph access token.""" + url = f"https://login.microsoftonline.com/{self._tenant_id}/oauth2/v2.0/token" + async with httpx.AsyncClient() as client: + resp = await client.post( + url, + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_secret": self._client_secret, + "scope": "https://graph.microsoft.com/.default", + }, + ) + resp.raise_for_status() + data = resp.json() + self.logger.info(f"App-only Graph token obtained (expires_in={data.get('expires_in')})") + return str(data["access_token"]) + + async def _exchange_sp_token_with_certificate(self, hostname: str) -> str: + """Exchange 
certificate credentials for a SharePoint REST API access token.""" + import base64 + import hashlib + import time as _time + + import jwt as pyjwt + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey + from cryptography.x509 import load_pem_x509_certificate + + token_url = f"https://login.microsoftonline.com/{self._tenant_id}/oauth2/v2.0/token" + + loaded_key = serialization.load_pem_private_key(self._private_key.encode(), password=None) + if not isinstance(loaded_key, RSAPrivateKey): + raise ValueError("SharePoint certificate auth requires an RSA private key") + private_key: RSAPrivateKey = loaded_key + + if not self._certificate: + raise ValueError( + "Certificate PEM is required for SP REST API token exchange. " + "Provide the PEM certificate that was uploaded to the Azure AD app registration." + ) + + cert = load_pem_x509_certificate(self._certificate.encode()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + cert_hash = hashlib.sha1(cert_der).digest() # noqa: S324 + x5t = base64.urlsafe_b64encode(cert_hash).rstrip(b"=").decode() + + now = int(_time.time()) + assertion = pyjwt.encode( + { + "aud": token_url, + "iss": self._client_id, + "sub": self._client_id, + "jti": str(now), + "nbf": now, + "exp": now + 600, + }, + private_key, + algorithm="RS256", + headers={"x5t": x5t}, + ) + + async with httpx.AsyncClient() as client: + resp = await client.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": self._client_id, + "client_assertion_type": ( + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + ), + "client_assertion": assertion, + "scope": f"https://{hostname}/.default", + }, + ) + resp.raise_for_status() + data = resp.json() + self.logger.info( + f"App-only SP token for {hostname} obtained (expires_in={data.get('expires_in')})" + ) + return str(data["access_token"]) + + async def _get_sp_token(self, hostname: str) -> str: + """Get a 
valid SP REST API token for a hostname, re-exchanging if expired.""" + now = asyncio.get_event_loop().time() + cached = self._sp_tokens.get(hostname) + if cached: + token, expires_at = cached + if now < expires_at: + return token + token = await self._exchange_sp_token_with_certificate(hostname) + self._sp_tokens[hostname] = (token, now + 3500) + return token + + # -- Auth hooks -- + + async def _get_access_token(self) -> str: + now = asyncio.get_event_loop().time() + if self._graph_token and now < self._graph_token_expires: + return self._graph_token + self._graph_token = await self._exchange_graph_token() + self._graph_token_expires = now + 3500 # ~58 min + return self._graph_token + + async def _handle_401(self) -> str: + self._graph_token_expires = 0 # force re-exchange + return await self._get_access_token() + + def _make_sp_token_provider_for_site(self, site_url: str) -> Optional[Callable]: + """Create SP token provider for a specific site URL via certificate exchange.""" + if not site_url: + return None + parsed = urlparse(site_url) + hostname = parsed.netloc + if not hostname: + return None + + async def _provider() -> str: + return await self._get_sp_token(hostname) + + return _provider + + @property + def _delta_prefer_headers(self) -> List[str]: + return [ + "deltashowsharingchanges", + "deltashowremovedasdeleted", + "deltatraversepermissiongaps", + ] + + async def _get_download_auth(self, url: str) -> Any: + """For client-credentials auth, use StaticTokenProvider for Graph URLs.""" + if "tempauth=" in url: + return self.auth # pre-signed URL, no auth needed + graph_token = await self._get_access_token() + return StaticTokenProvider(graph_token) + + async def _discover_sites(self, graph_client: GraphClient) -> List[Dict[str, Any]]: + """Discover sites via getAllSites (application permissions). + + When site_url is set: resolve the specific site. + When empty: use getAllSites for complete enumeration. 
+ """ + sites = [] + + if self._site_url: + parsed = urlparse(self._site_url) + hostname = parsed.netloc + site_path = parsed.path.lstrip("/") + try: + site = await graph_client.get_site_by_url(hostname, site_path) + sites.append(site) + except SourceAuthError: + raise + except Exception as e: + self.logger.warning(f"Could not resolve site URL {self._site_url}: {e}") + raise + else: + async for site in graph_client.get_all_sites(): + if not self._include_personal_sites and site.get("isPersonalSite", False): + continue + sites.append(site) + + self.logger.info(f"Discovered {len(sites)} sites to sync") + return sites diff --git a/backend/conftest.py b/backend/conftest.py index 84ad6698d..532161998 100644 --- a/backend/conftest.py +++ b/backend/conftest.py @@ -210,11 +210,11 @@ def fake_health_service() -> FakeHealthService: @pytest.fixture -def fake_source_connection_service(fake_sync_lifecycle): +def fake_source_connection_service(fake_sync_service): """Fake SourceConnectionService.""" from airweave.domains.source_connections.fakes.service import FakeSourceConnectionService - return FakeSourceConnectionService(sync_lifecycle=fake_sync_lifecycle) + return FakeSourceConnectionService(sync_service=fake_sync_service) @pytest.fixture @@ -375,11 +375,9 @@ def fake_billing_service(): @pytest.fixture -def fake_sync_record_service(): - """Fake SyncRecordService.""" - from airweave.domains.syncs.fakes.record_service import FakeSyncRecordService - - return FakeSyncRecordService() +def fake_sync_record_service(fake_sync_service): + """Legacy fixture — returns the unified FakeSyncService for backward compatibility.""" + return fake_sync_service @pytest.fixture @@ -407,11 +405,9 @@ def fake_sync_service(): @pytest.fixture -def fake_sync_lifecycle(): - """Fake SyncLifecycleService.""" - from airweave.domains.syncs.fakes.lifecycle_service import FakeSyncLifecycleService - - return FakeSyncLifecycleService() +def fake_sync_lifecycle(fake_sync_service): + """Legacy fixture — returns 
the unified FakeSyncService for backward compatibility.""" + return fake_sync_service @pytest.fixture @@ -717,12 +713,9 @@ def test_container( fake_sync_cursor_repo, fake_sync_cursor_service, fake_sync_job_repo, - fake_sync_record_service, fake_sync_job_service, fake_sync_job_state_machine, - fake_sync_state_machine, fake_sync_service, - fake_sync_lifecycle, fake_billing_service, fake_billing_webhook, fake_payment_gateway, @@ -809,12 +802,9 @@ def test_container( sync_cursor_repo=fake_sync_cursor_repo, sync_cursor_service=fake_sync_cursor_service, sync_job_repo=fake_sync_job_repo, - sync_record_service=fake_sync_record_service, sync_job_service=fake_sync_job_service, sync_job_state_machine=fake_sync_job_state_machine, - sync_state_machine=fake_sync_state_machine, sync_service=fake_sync_service, - sync_lifecycle=fake_sync_lifecycle, billing_service=fake_billing_service, billing_webhook=fake_billing_webhook, payment_gateway=fake_payment_gateway, diff --git a/backend/poetry.lock b/backend/poetry.lock index edfff3ff3..7dfbf176d 100644 --- a/backend/poetry.lock +++ b/backend/poetry.lock @@ -1625,21 +1625,6 @@ wrapt = ">=1.10,<3" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools ; python_version >= \"3.12\"", "tox"] -[[package]] -name = "deprecation" -version = "2.1.0" -description = "A library to handle automated deprecations" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, - {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, -] - -[package.dependencies] -packaging = "*" - [[package]] name = "diff-cover" version = "10.2.0" @@ -2677,83 +2662,6 @@ typing-extensions = ">=4.10,<5" [package.extras] aiohttp = ["aiohttp", "httpx-aiohttp (>=0.1.8)"] -[[package]] -name = "grpcio" -version = "1.76.0" -description = 
"HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "grpcio-1.76.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:65a20de41e85648e00305c1bb09a3598f840422e522277641145a32d42dcefcc"}, - {file = "grpcio-1.76.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:40ad3afe81676fd9ec6d9d406eda00933f218038433980aa19d401490e46ecde"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:035d90bc79eaa4bed83f524331d55e35820725c9fbb00ffa1904d5550ed7ede3"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:4215d3a102bd95e2e11b5395c78562967959824156af11fa93d18fdd18050990"}, - {file = "grpcio-1.76.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:49ce47231818806067aea3324d4bf13825b658ad662d3b25fada0bdad9b8a6af"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8cc3309d8e08fd79089e13ed4819d0af72aa935dd8f435a195fd152796752ff2"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:971fd5a1d6e62e00d945423a567e42eb1fa678ba89072832185ca836a94daaa6"}, - {file = "grpcio-1.76.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9d9adda641db7207e800a7f089068f6f645959f2df27e870ee81d44701dd9db3"}, - {file = "grpcio-1.76.0-cp310-cp310-win32.whl", hash = "sha256:063065249d9e7e0782d03d2bca50787f53bd0fb89a67de9a7b521c4a01f1989b"}, - {file = "grpcio-1.76.0-cp310-cp310-win_amd64.whl", hash = "sha256:a6ae758eb08088d36812dd5d9af7a9859c05b1e0f714470ea243694b49278e7b"}, - {file = "grpcio-1.76.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a"}, - {file = "grpcio-1.76.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = 
"sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48"}, - {file = "grpcio-1.76.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749"}, - {file = "grpcio-1.76.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00"}, - {file = "grpcio-1.76.0-cp311-cp311-win32.whl", hash = "sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054"}, - {file = "grpcio-1.76.0-cp311-cp311-win_amd64.whl", hash = "sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d"}, - {file = "grpcio-1.76.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8"}, - {file = "grpcio-1.76.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11"}, - {file = "grpcio-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980"}, - {file = "grpcio-1.76.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882"}, - {file = "grpcio-1.76.0-cp312-cp312-win32.whl", hash = "sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958"}, - {file = "grpcio-1.76.0-cp312-cp312-win_amd64.whl", hash = "sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347"}, - {file = "grpcio-1.76.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:26ef06c73eb53267c2b319f43e6634c7556ea37672029241a056629af27c10e2"}, - {file = "grpcio-1.76.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:45e0111e73f43f735d70786557dc38141185072d7ff8dc1829d6a77ac1471468"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:83d57312a58dcfe2a3a0f9d1389b299438909a02db60e2f2ea2ae2d8034909d3"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:3e2a27c89eb9ac3d81ec8835e12414d73536c6e620355d65102503064a4ed6eb"}, - {file = "grpcio-1.76.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:61f69297cba3950a524f61c7c8ee12e55c486cb5f7db47ff9dcee33da6f0d3ae"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6a15c17af8839b6801d554263c546c69c4d7718ad4321e3166175b37eaacca77"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:25a18e9810fbc7e7f03ec2516addc116a957f8cbb8cbc95ccc80faa072743d03"}, - {file = "grpcio-1.76.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:931091142fd8cc14edccc0845a79248bc155425eee9a98b2db2ea4f00a235a42"}, - {file = "grpcio-1.76.0-cp313-cp313-win32.whl", hash = 
"sha256:5e8571632780e08526f118f74170ad8d50fb0a48c23a746bef2a6ebade3abd6f"}, - {file = "grpcio-1.76.0-cp313-cp313-win_amd64.whl", hash = "sha256:f9f7bd5faab55f47231ad8dba7787866b69f5e93bc306e3915606779bbfb4ba8"}, - {file = "grpcio-1.76.0-cp314-cp314-linux_armv7l.whl", hash = "sha256:ff8a59ea85a1f2191a0ffcc61298c571bc566332f82e5f5be1b83c9d8e668a62"}, - {file = "grpcio-1.76.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:06c3d6b076e7b593905d04fdba6a0525711b3466f43b3400266f04ff735de0cd"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fd5ef5932f6475c436c4a55e4336ebbe47bd3272be04964a03d316bbf4afbcbc"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b331680e46239e090f5b3cead313cc772f6caa7d0fc8de349337563125361a4a"}, - {file = "grpcio-1.76.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2229ae655ec4e8999599469559e97630185fdd53ae1e8997d147b7c9b2b72cba"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:490fa6d203992c47c7b9e4a9d39003a0c2bcc1c9aa3c058730884bbbb0ee9f09"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:479496325ce554792dba6548fae3df31a72cef7bad71ca2e12b0e58f9b336bfc"}, - {file = "grpcio-1.76.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1c9b93f79f48b03ada57ea24725d83a30284a012ec27eab2cf7e50a550cbbbcc"}, - {file = "grpcio-1.76.0-cp314-cp314-win32.whl", hash = "sha256:747fa73efa9b8b1488a95d0ba1039c8e2dca0f741612d80415b1e1c560febf4e"}, - {file = "grpcio-1.76.0-cp314-cp314-win_amd64.whl", hash = "sha256:922fa70ba549fce362d2e2871ab542082d66e2aaf0c19480ea453905b01f384e"}, - {file = "grpcio-1.76.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:8ebe63ee5f8fa4296b1b8cfc743f870d10e902ca18afc65c68cf46fd39bb0783"}, - {file = "grpcio-1.76.0-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:3bf0f392c0b806905ed174dcd8bdd5e418a40d5567a05615a030a5aeddea692d"}, - 
{file = "grpcio-1.76.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0b7604868b38c1bfd5cf72d768aedd7db41d78cb6a4a18585e33fb0f9f2363fd"}, - {file = "grpcio-1.76.0-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:e6d1db20594d9daba22f90da738b1a0441a7427552cc6e2e3d1297aeddc00378"}, - {file = "grpcio-1.76.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d099566accf23d21037f18a2a63d323075bebace807742e4b0ac210971d4dd70"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:ebea5cc3aa8ea72e04df9913492f9a96d9348db876f9dda3ad729cfedf7ac416"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0c37db8606c258e2ee0c56b78c62fc9dee0e901b5dbdcf816c2dd4ad652b8b0c"}, - {file = "grpcio-1.76.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ebebf83299b0cb1721a8859ea98f3a77811e35dce7609c5c963b9ad90728f886"}, - {file = "grpcio-1.76.0-cp39-cp39-win32.whl", hash = "sha256:0aaa82d0813fd4c8e589fac9b65d7dd88702555f702fb10417f96e2a2a6d4c0f"}, - {file = "grpcio-1.76.0-cp39-cp39-win_amd64.whl", hash = "sha256:acab0277c40eff7143c2323190ea57b9ee5fd353d8190ee9652369fae735668a"}, - {file = "grpcio-1.76.0.tar.gz", hash = "sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73"}, -] - -[package.dependencies] -typing-extensions = ">=4.12,<5.0" - -[package.extras] -protobuf = ["grpcio-tools (>=1.76.0)"] - [[package]] name = "h11" version = "0.16.0" @@ -8013,21 +7921,6 @@ dev = ["Cython (>=3.0,<4.0)", "setuptools (>=60)"] docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx_rtd_theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] test = ["aiohttp (>=3.10.5)", "flake8 (>=6.1,<7.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=25.3.0,<25.4.0)", "pycodestyle (>=2.11.0,<2.12.0)"] -[[package]] -name = "validators" -version = "0.35.0" -description = "Python Data Validation for Humans™" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - 
{file = "validators-0.35.0-py3-none-any.whl", hash = "sha256:e8c947097eae7892cb3d26868d637f79f47b4a0554bc6b80065dfe5aac3705dd"}, - {file = "validators-0.35.0.tar.gz", hash = "sha256:992d6c48a4e77c81f1b4daba10d16c3a9bb0dbb79b3a19ea847ff0928e70497a"}, -] - -[package.extras] -crypto-eth-addresses = ["eth-hash[pycryptodome] (>=0.7.0)"] - [[package]] name = "virtualenv" version = "20.36.1" @@ -8183,42 +8076,6 @@ files = [ [package.dependencies] anyio = ">=3.0.0" -[[package]] -name = "weaviate" -version = "0.1.2" -description = "A placeholder package for the Weaviate name" -optional = false -python-versions = "*" -groups = ["main"] -files = [ - {file = "weaviate-0.1.2-py3-none-any.whl", hash = "sha256:40f1c1cf0b769036315d2b6026c8cd823a3a6e951c90d4e70a001a770ba8a444"}, - {file = "weaviate-0.1.2.tar.gz", hash = "sha256:a381b8bb0eb236bd10256def8612953ed9024e6738b8a259e7ec11e626ae0665"}, -] - -[[package]] -name = "weaviate-client" -version = "4.19.2" -description = "A python native Weaviate client" -optional = false -python-versions = ">=3.10" -groups = ["main"] -files = [ - {file = "weaviate_client-4.19.2-py3-none-any.whl", hash = "sha256:e78306d47c574c4035c87223e480bb77bd6e54142a21c4c58522dd43019fe493"}, - {file = "weaviate_client-4.19.2.tar.gz", hash = "sha256:99e76e912c95762436089cd5feedbfeea31e892aa13b6ad94729a2a54b316c45"}, -] - -[package.dependencies] -authlib = ">=1.6.5,<2.0.0" -deprecation = ">=2.1.0,<3.0.0" -grpcio = ">=1.59.5,<1.80.0" -httpx = ">=0.26.0,<0.29.0" -protobuf = ">=4.21.6,<7.0.0" -pydantic = ">=2.12.0,<3.0.0" -validators = ">=0.34.0,<1.0.0" - -[package.extras] -agents = ["weaviate-agents (>=1.0.0,<2.0.0)"] - [[package]] name = "websockets" version = "16.0" @@ -8610,4 +8467,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.13,<3.14" -content-hash = "2065fc338a5f211587a8838f8c3dcee58987ec94e5613452be47ac8dec37e6d0" +content-hash = "60288465ee2e36d9dd0099a018a0949f08dea9e3f96abad034a66d7976317d0c" diff --git 
a/backend/pyproject.toml b/backend/pyproject.toml index c092253a3..f452f7896 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -28,8 +28,6 @@ tenacity = "^8.2.3" structlog = "^24.1.0" pydantic-settings = "^2.7.0" psycopg2-binary = "^2.9.10" -weaviate = "^0.1.2" -weaviate-client = "^4.10.2" markitdown = "^0.0.1a3" neo4j = "^5.27.0" pyodbc = "^5.2.0" diff --git a/backend/tests/e2e/smoke/test_sources.py b/backend/tests/e2e/smoke/test_sources.py index da92dd856..62ca82e9d 100644 --- a/backend/tests/e2e/smoke/test_sources.py +++ b/backend/tests/e2e/smoke/test_sources.py @@ -154,7 +154,7 @@ async def test_supported_auth_providers_structure(self, api_client: httpx.AsyncC sources = response.json() # Known valid auth provider short names - valid_providers = ["pipedream", "composio"] + valid_providers = ["pipedream", "composio", "custom"] for source in sources: providers = source.get("supported_auth_providers", []) diff --git a/backend/tests/unit/platform/configs/test_config_ssrf.py b/backend/tests/unit/platform/configs/test_config_ssrf.py index 115e4daf7..832487b7a 100644 --- a/backend/tests/unit/platform/configs/test_config_ssrf.py +++ b/backend/tests/unit/platform/configs/test_config_ssrf.py @@ -217,21 +217,23 @@ def test_accepts_public_site(self): class TestSharePointOnlineConfig: - def test_rejects_loopback_in_csv(self): + def test_rejects_loopback(self): with pytest.raises(ValidationError, match="SSRF|blocked"): - SharePointOnlineConfig( - site_url="https://ok.sharepoint.com, http://127.0.0.1" - ) + SharePointOnlineConfig(site_url="http://127.0.0.1") def test_empty_site_url_passes(self): cfg = SharePointOnlineConfig(site_url="") assert cfg.site_url == "" - def test_accepts_valid_csv(self): + def test_missing_site_url_defaults_empty(self): + cfg = SharePointOnlineConfig() + assert cfg.site_url == "" + + def test_accepts_valid_site_url(self): cfg = SharePointOnlineConfig( - site_url="https://a.sharepoint.com, https://b.sharepoint.com" + 
site_url="https://contoso.sharepoint.com/sites/Marketing" ) - assert cfg.site_url == "https://a.sharepoint.com, https://b.sharepoint.com" + assert cfg.site_url == "https://contoso.sharepoint.com/sites/Marketing" class TestSalesforceConfig: diff --git a/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py new file mode 100644 index 000000000..a976f9802 --- /dev/null +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_acl.py @@ -0,0 +1,286 @@ +"""Unit tests for SharePoint Online ACL extraction. + +Covers ``extract_access_control`` and the rules around how Microsoft Graph +sharing-link permissions map to ``AccessControl``. + +Background — two related bug fixes: + +1. Organization-scoped sharing links (``link.scope == "organization"``, + "anyone in your org with the link") used to set ``is_public = True``, + which made the search broker bypass all viewer checks. Only + ``link.scope == "anonymous"`` is genuine public access. + +2. The previous code attempted to recover sharing-link audience by + blanket-attaching every site group (including SharingLinks system + groups for unrelated files) to every file in the site. That over-granted + massively. The fix is to translate each link permission into the + specific ``SharingLinks...`` SP site group + for that one file, scoped by the file's SharePoint UniqueId. 
+""" + +import pytest + +from airweave.platform.sources.sharepoint_online.acl import ( + extract_access_control, + link_permission_to_sp_group_viewer, +) + +# --------------------------------------------------------------------------- +# Helpers — build minimal Graph permission objects +# --------------------------------------------------------------------------- + + +def _link_perm( + scope: str, type_: str = "edit", roles=None, link_id: str = "link-1" +) -> dict: + """Sharing-link permission with the given scope (no grantedTo principal).""" + return { + "id": link_id, + "roles": roles if roles is not None else ["write"], + "link": {"scope": scope, "type": type_}, + "grantedToIdentitiesV2": [], + "grantedToIdentities": [], + } + + +def _site_group_perm(name: str, group_id: str = "5", roles=None) -> dict: + return { + "id": f"sg-{group_id}", + "roles": roles if roles is not None else ["write"], + "grantedToV2": {"siteGroup": {"displayName": name, "id": group_id}}, + } + + +def _user_perm(email: str, roles=None) -> dict: + return { + "id": f"u-{email}", + "roles": roles if roles is not None else ["read"], + "grantedToV2": {"user": {"email": email, "displayName": email}}, + } + + +def _entra_group_perm(group_id: str, roles=None) -> dict: + return { + "id": f"eg-{group_id}", + "roles": roles if roles is not None else ["read"], + "grantedToV2": {"group": {"id": group_id}}, + } + + +# --------------------------------------------------------------------------- +# Sharing-link scope handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_organization_scoped_link_does_not_set_is_public(): + """Org-scoped link by itself must not flip is_public. + + Regression: the previous behavior treated organization-scoped links as + fully public, bypassing all viewer checks at search time. 
+ """ + # Without sp_unique_id the link cannot be translated into a SharingLinks + # site group viewer either — both halves of the fix combine to give + # "no public access, no viewer either". + ac = await extract_access_control([_link_perm("organization")]) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_organization_edit_link_with_sp_unique_id_yields_per_link_viewer(): + """When the file's SP UniqueId is known, an org+edit link translates.""" + perm = _link_perm("organization", type_="edit", link_id="LINK0001") + ac = await extract_access_control( + [perm], sp_unique_id="dd7691b0-3468-446f-81b0-72f3bdab7d1f" + ) + assert ac.is_public is False + assert ac.viewers == [ + "group:sp:sharinglinks.dd7691b0-3468-446f-81b0-72f3bdab7d1f.organizationedit.link0001" + ] + + +@pytest.mark.asyncio +async def test_organization_view_link_translates_to_organizationview_suffix(): + perm = _link_perm("organization", type_="view", link_id="LINK0002") + ac = await extract_access_control([perm], sp_unique_id="aaaa-bbbb") + assert ac.viewers == ["group:sp:sharinglinks.aaaa-bbbb.organizationview.link0002"] + + +@pytest.mark.asyncio +async def test_users_scope_link_translates_to_flexible_suffix(): + """Empirically verified: both users+edit and users+view collapse to Flexible.""" + perm_edit = _link_perm("users", type_="edit", link_id="LINKE") + perm_view = _link_perm("users", type_="view", link_id="LINKV") + ac_e = await extract_access_control([perm_edit], sp_unique_id="ITEM1") + ac_v = await extract_access_control([perm_view], sp_unique_id="ITEM1") + assert ac_e.viewers == ["group:sp:sharinglinks.item1.flexible.linke"] + assert ac_v.viewers == ["group:sp:sharinglinks.item1.flexible.linkv"] + + +@pytest.mark.asyncio +async def test_anonymous_link_does_not_get_translated_to_viewer(): + """Anonymous → is_public, never a SharingLinks viewer.""" + perm = _link_perm("anonymous", type_="view", link_id="LINKA") + ac = await 
extract_access_control([perm], sp_unique_id="ITEM1") + assert ac.is_public is True + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_anonymous_link_sets_is_public(): + ac = await extract_access_control([_link_perm("anonymous")]) + assert ac.is_public is True + + +@pytest.mark.asyncio +async def test_org_and_anonymous_links_together_still_public_via_anonymous(): + ac = await extract_access_control([_link_perm("organization"), _link_perm("anonymous")]) + assert ac.is_public is True + + +@pytest.mark.asyncio +async def test_users_scoped_link_does_not_set_is_public(): + """``users``-scoped links target named recipients, not the org.""" + ac = await extract_access_control([_link_perm("users")]) + assert ac.is_public is False + + +@pytest.mark.asyncio +async def test_unknown_link_scope_does_not_set_is_public(): + """Future / unrecognized scopes default to non-public.""" + ac = await extract_access_control([_link_perm("someFutureScope")]) + assert ac.is_public is False + + +# --------------------------------------------------------------------------- +# Mixed permissions — the realistic case +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_org_link_alongside_explicit_grants_extracts_grants_and_link_group(): + """Org-link plus explicit grants → both end up in viewers. + + Mirrors the Mistral bug-report payload shape: a file with one + organization-scoped sharing link plus inherited site-group grants. + Post-fix, is_public is False, all explicit grants are present, and + the per-link SharingLinks site group is included exactly once. 
+ """ + perms = [ + _link_perm("organization", link_id="LINK0001"), + _site_group_perm("Access Control Tests Owners", group_id="3", roles=["owner"]), + _site_group_perm("Access Control Tests Members", group_id="5", roles=["write"]), + _site_group_perm("Access Control Tests Visitors", group_id="4", roles=["read"]), + ] + ac = await extract_access_control(perms, sp_unique_id="ITEM1") + assert ac.is_public is False + assert set(ac.viewers) == { + "group:sp:access_control_tests_owners", + "group:sp:access_control_tests_members", + "group:sp:access_control_tests_visitors", + "group:sp:sharinglinks.item1.organizationedit.link0001", + } + + +@pytest.mark.asyncio +async def test_user_and_entra_group_grants_extracted(): + perms = [ + _user_perm("alice@example.com"), + _entra_group_perm("11111111-2222-3333-4444-555555555555"), + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert set(ac.viewers) == { + "user:alice@example.com", + "group:entra:11111111-2222-3333-4444-555555555555", + } + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_empty_permissions_returns_empty_access_control(): + ac = await extract_access_control([]) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_permission_without_read_role_is_ignored(): + """Roles other than read/write/owner/sp.full control don't grant viewing.""" + perms = [ + { + "id": "restricted", + "roles": ["restricted"], + "grantedToV2": {"user": {"email": "alice@example.com"}}, + }, + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_org_link_without_read_role_is_ignored_entirely(): + """A link without a read-equivalent role doesn't even reach scope check.""" + perms = [ + { + "id": "link-restricted", + 
"roles": ["restricted"], + "link": {"scope": "organization"}, + } + ] + ac = await extract_access_control(perms) + assert ac.is_public is False + assert ac.viewers == [] + + +@pytest.mark.asyncio +async def test_duplicate_principal_only_added_once(): + perms = [ + _user_perm("alice@example.com", roles=["read"]), + _user_perm("alice@example.com", roles=["write"]), + ] + ac = await extract_access_control(perms) + assert ac.viewers == ["user:alice@example.com"] + + +# --------------------------------------------------------------------------- +# link_permission_to_sp_group_viewer — None-return paths +# --------------------------------------------------------------------------- + + +def test_link_translation_returns_none_without_sp_unique_id(): + """No SP UniqueId means we can't construct the group name — return None.""" + perm = _link_perm("organization", link_id="L1") + assert link_permission_to_sp_group_viewer(perm, None) is None + + +def test_link_translation_returns_none_for_non_link_perm(): + """A non-link permission is not a sharing link — return None.""" + perm = {"id": "x", "roles": ["read"], "grantedToV2": {"user": {"email": "a@b.com"}}} + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_for_anonymous(): + perm = _link_perm("anonymous", link_id="L1") + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_for_unknown_scope(): + """Unknown / future scope: be conservative, don't fabricate a viewer.""" + perm = { + "id": "L1", + "roles": ["read"], + "link": {"scope": "future-scope", "type": "edit"}, + } + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None + + +def test_link_translation_returns_none_when_link_id_missing(): + perm = {"id": "", "roles": ["read"], "link": {"scope": "organization", "type": "edit"}} + assert link_permission_to_sp_group_viewer(perm, "ITEM1") is None diff --git 
a/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py new file mode 100644 index 000000000..45c2f8c54 --- /dev/null +++ b/backend/tests/unit/platform/sources/test_sharepoint_online_group_expansion.py @@ -0,0 +1,406 @@ +"""Unit tests for SharePoint Online SP site group expansion helpers. + +Covers _parse_sp_group_member, _email_from_membership_login, and the cursor +migration path for tracked_sp_groups. +""" + +from unittest.mock import MagicMock + +from airweave.platform.sources.sharepoint_online.source import ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + SharePointOnlineBase, +) + +# --------------------------------------------------------------------------- +# _email_from_membership_login +# --------------------------------------------------------------------------- + + +def test_email_from_membership_login_valid(): + assert ( + SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|foo@bar.com") + == "foo@bar.com" + ) + + +def test_email_from_membership_login_uppercase_normalized(): + assert ( + SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|Foo@BAR.com") + == "foo@bar.com" + ) + + +def test_email_from_membership_login_rejects_role_principal(): + # Role principals would otherwise yield "spo-grid-all-users/..." — must reject. 
+ assert ( + SharePointOnlineBase._email_from_membership_login( + "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45" + ) + is None + ) + + +def test_email_from_membership_login_rejects_federated_group(): + assert ( + SharePointOnlineBase._email_from_membership_login( + "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579" + ) + is None + ) + + +def test_email_from_membership_login_rejects_empty(): + assert SharePointOnlineBase._email_from_membership_login("") is None + + +def test_email_from_membership_login_rejects_malformed(): + assert SharePointOnlineBase._email_from_membership_login("i:0#.f|membership|noat") is None + + +# --------------------------------------------------------------------------- +# _parse_sp_group_member +# --------------------------------------------------------------------------- + + +def test_parse_real_user_with_email(): + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|alice@contoso.com", + "Email": "alice@contoso.com", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_uppercase_email_normalized(): + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|ALICE@CONTOSO.COM", + "Email": "ALICE@CONTOSO.COM", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_email_empty_fallback_to_login(): + # If Email is missing but LoginName has the membership pattern, use that. + user = { + "PrincipalType": 1, + "LoginName": "i:0#.f|membership|alice@contoso.com", + "Email": "", + "Title": "Alice", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "alice@contoso.com", + "user", + ) + + +def test_parse_real_user_no_email_no_parseable_login_returns_none(): + # System Account and similar — no Email, no membership LoginName. 
+ user = { + "PrincipalType": 1, + "LoginName": "SHAREPOINT\\system", + "Email": "", + "Title": "System Account", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_role_principal_skipped(): + """Bug B regression test — 'Everyone except external users' must not become a fake user.""" + user = { + "PrincipalType": 16, + "LoginName": "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45", + "Email": "", + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_entra_group_emits_group_membership(): + """Bug C/D regression test — Entra group must be emitted as group-to-group.""" + user = { + "PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579", + "Email": "neena@neenacorp.onmicrosoft.com", # group's email, must NOT be used + "Title": "Neena Members", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_owner_suffix_stripped(): + """Owner-style claim has `_o` suffix — must strip it to get the bare GUID.""" + login = "c:0o.c|federateddirectoryclaimprovider|58cb1814-203a-44d0-8578-b53f63860579_o" + user = { + "PrincipalType": 4, + "LoginName": login, + "Email": "neena@neenacorp.onmicrosoft.com", + "Title": "Neena Owners", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_uppercase_guid_normalized(): + user = { + "PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|58CB1814-203A-44D0-8578-B53F63860579", + "Title": "Neena Owners", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + "entra:58cb1814-203a-44d0-8578-b53f63860579", + "group", + ) + + +def test_parse_entra_group_malformed_guid_returns_none(): + user = { + 
"PrincipalType": 4, + "LoginName": "c:0o.c|federateddirectoryclaimprovider|not-a-guid", + "Title": "Bad Group", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_security_group_non_federated_returns_none(): + # PrincipalType=4 but not federated — on-prem AD claim, skip. + user = { + "PrincipalType": 4, + "LoginName": "c:0-.f|adclaimprovider|S-1-5-21-...", + "Title": "On-prem Group", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_everyone_except_external_claim_returns_synthetic_sentinel(): + """The rolemanager/spo-grid-all-users claim → synthetic group sentinel.""" + user = { + "PrincipalType": 4, + "LoginName": ( + "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45" + ), + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + "group", + ) + + +def test_parse_everyone_except_external_uppercase_tenant_id(): + """Tenant ID GUIDs may be upper- or lowercase; both should match.""" + user = { + "PrincipalType": 4, + "LoginName": ( + "c:0-.f|rolemanager|spo-grid-all-users/26ADF163-2699-4D04-A0AD-3D935411BF45" + ), + "Title": "Everyone except external users", + } + assert SharePointOnlineBase._parse_sp_group_member(user) == ( + EVERYONE_EXCEPT_EXTERNAL_PRINCIPAL, + "group", + ) + + +def test_parse_other_rolemanager_claim_skipped_and_flagged_as_unrecognized(): + """A different rolemanager claim shouldn't match; should be flagged for logging.""" + user = { + "PrincipalType": 4, + "LoginName": "c:0-.f|rolemanager|some-future-claim", + "Title": "Custom Role", + } + assert SharePointOnlineBase._parse_sp_group_member(user) is None + assert SharePointOnlineBase._is_unrecognized_pt4_login(user) is True + + +def test_is_unrecognized_pt4_login_false_for_known_shapes(): + """Known PT=4 shapes (Entra group, claim) must NOT be flagged as unrecognized.""" + entra = { + 
"PrincipalType": 4, + "LoginName": ( + "c:0o.c|federateddirectoryclaimprovider|7d344400-39bc-4ee7-aa6e-437bd8de85c0" + ), + } + claim = { + "PrincipalType": 4, + "LoginName": "c:0-.f|rolemanager|spo-grid-all-users/26adf163-2699-4d04-a0ad-3d935411bf45", + } + assert SharePointOnlineBase._is_unrecognized_pt4_login(entra) is False + assert SharePointOnlineBase._is_unrecognized_pt4_login(claim) is False + + +def test_is_unrecognized_pt4_login_false_for_non_pt4(): + """The flag is scoped to PT=4 only; other PrincipalTypes are skipped silently.""" + user = {"PrincipalType": 1, "LoginName": "i:0#.f|membership|alice@example.com"} + assert SharePointOnlineBase._is_unrecognized_pt4_login(user) is False + + +def test_parse_distlist_skipped(): + user = {"PrincipalType": 2, "LoginName": "some-dl", "Title": "DL"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_spgroup_skipped(): + user = {"PrincipalType": 8, "LoginName": "some-sp", "Title": "SP"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_all_catchall_skipped(): + user = {"PrincipalType": 15, "LoginName": "everyone", "Title": "All"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_unknown_principal_type_skipped(): + user = {"PrincipalType": 99, "LoginName": "x", "Title": "X"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +def test_parse_missing_principal_type_skipped(): + user = {"LoginName": "x", "Title": "X", "Email": "x@y.z"} + assert SharePointOnlineBase._parse_sp_group_member(user) is None + + +# --------------------------------------------------------------------------- +# _normalize_site_url +# --------------------------------------------------------------------------- + + +def test_normalize_site_url_strips_trailing_slash(): + assert ( + SharePointOnlineBase._normalize_site_url("https://contoso.sharepoint.com/sites/X/") + == "https://contoso.sharepoint.com/sites/X" + ) + + +def 
test_normalize_site_url_empty(): + assert SharePointOnlineBase._normalize_site_url("") == "" + assert SharePointOnlineBase._normalize_site_url(None) == "" # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# _track_entity_groups with site_url scoping +# --------------------------------------------------------------------------- + + +class _StubEntity: + def __init__(self, viewers): + self.access = MagicMock() + self.access.viewers = viewers + + +def _bare_base() -> SharePointOnlineBase: + """Instantiate the base class just enough to exercise tracking logic. + + We bypass the normal source creation path since we only need the tracking + state and its methods. + """ + instance = SharePointOnlineBase.__new__(SharePointOnlineBase) + instance._site_url = "" + instance._include_personal_sites = False + instance._include_pages = False + instance._item_level_entra_groups = set() + instance._item_level_sp_groups = {} + return instance + + +def test_track_entity_groups_scopes_sp_by_site(): + base = _bare_base() + e = _StubEntity( + [ + "group:sp:neena_members", + "group:sp:neena_owners", + "group:entra:58cb1814-203a-44d0-8578-b53f63860579", + "user:alice@contoso.com", + ] + ) + base._track_entity_groups(e, "https://neenacorp.sharepoint.com/sites/Neena77") + + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/Neena77": { + "sp:neena_members", + "sp:neena_owners", + } + } + assert base._item_level_entra_groups == {"entra:58cb1814-203a-44d0-8578-b53f63860579"} + + +def test_track_entity_groups_multiple_sites_keep_separate(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:neena_members"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + base._track_entity_groups( + _StubEntity(["group:sp:access_control_tests_owners"]), + "https://neenacorp.sharepoint.com/sites/B", + ) + + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/A": 
{"sp:neena_members"}, + "https://neenacorp.sharepoint.com/sites/B": {"sp:access_control_tests_owners"}, + } + + +def test_track_entity_groups_same_name_different_sites_do_not_collide(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:members"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + base._track_entity_groups( + _StubEntity(["group:sp:members"]), + "https://neenacorp.sharepoint.com/sites/B", + ) + + # Same group name but two different sites — must be tracked independently. + assert set(base._item_level_sp_groups.keys()) == { + "https://neenacorp.sharepoint.com/sites/A", + "https://neenacorp.sharepoint.com/sites/B", + } + + +def test_track_entity_groups_normalizes_trailing_slash(): + base = _bare_base() + base._track_entity_groups( + _StubEntity(["group:sp:x"]), + "https://neenacorp.sharepoint.com/sites/A/", + ) + base._track_entity_groups( + _StubEntity(["group:sp:y"]), + "https://neenacorp.sharepoint.com/sites/A", + ) + # Both should land under the same normalized key. + assert base._item_level_sp_groups == { + "https://neenacorp.sharepoint.com/sites/A": {"sp:x", "sp:y"} + } + + +def test_track_entity_groups_no_access_noop(): + base = _bare_base() + entity = MagicMock() + entity.access = None + base._track_entity_groups(entity, "https://neenacorp.sharepoint.com/sites/A") + assert base._item_level_sp_groups == {} + + +def test_track_entity_groups_empty_site_url_still_stores_under_empty_key(): + """Groups are still stored under the empty-string key when no site_url. + + Expansion skips empty-key buckets, so this is effectively a no-op for + broker purposes but keeps the data structure consistent. 
+ """ + base = _bare_base() + base._track_entity_groups(_StubEntity(["group:sp:orphan"]), "") + assert base._item_level_sp_groups == {"": {"sp:orphan"}} diff --git a/backend/tests/unit/platform/sync/test_token_providers.py b/backend/tests/unit/platform/sync/test_token_providers.py index 85803dee4..fc194117a 100644 --- a/backend/tests/unit/platform/sync/test_token_providers.py +++ b/backend/tests/unit/platform/sync/test_token_providers.py @@ -221,7 +221,7 @@ def _make_provider(self, access_token: str = "fresh_token"): mock_registry = MagicMock() entry = MagicMock() entry.runtime_auth_all_fields = ["access_token"] - entry.runtime_auth_optional_fields = [] + entry.runtime_auth_optional_fields = set() mock_registry.get.return_value = entry return AuthProviderTokenProvider( diff --git a/fern/docs/pages/concepts.mdx b/fern/docs/pages/concepts.mdx index a5387018f..9743e2141 100644 --- a/fern/docs/pages/concepts.mdx +++ b/fern/docs/pages/concepts.mdx @@ -6,10 +6,9 @@ slug: concepts Airweave connects to your apps, databases, and documents, then turns them into knowledge you can search. To understand how it works, you only need a few core concepts. - ## Source -A **Source** is an external application, database, or document store where your data lives. Sources are the systems Airweave pulls data from to build your retrieval layer. +A **Source** is an external application, database, or document store where your data lives. Sources are the systems from which Airweave pulls data to build your retrieval layer. Sources can be: - **Productivity tools**: Notion, Slack, Asana, Jira, Confluence @@ -23,7 +22,7 @@ Each source type has its own data structures, authentication methods, and API pa A **Connector** is the integration code that allows Airweave to communicate with a specific source. 
Each connector handles: -- **Authentication**: OAuth flows, API keys, or database credentials depending on the source +- **Authentication**: Using OAuth flows, API keys, or database credentials depending on the source - **Data extraction**: Fetching records, documents, or rows from the source API - **Entity mapping**: Transforming source-specific data structures into Airweave's unified entity format - **Incremental sync**: Tracking changes so only new or modified data is re-synced @@ -39,7 +38,7 @@ A **Source Connection** is a configured, authenticated instance of a connector l When you create a source connection, you: 1. Select a connector (e.g., Slack) 2. Authenticate with your credentials (e.g., OAuth login to your Slack workspace) -3. Assign it to a collection +3. Assign the connection to a collection Once created, Airweave continuously syncs data from that source connection, keeping your retrieval layer fresh and up-to-date. You can have multiple source connections of the same type (e.g., connecting to several different Slack workspaces). @@ -47,10 +46,10 @@ Once created, Airweave continuously syncs data from that source connection, keep An **Entity** is a single, searchable item extracted from a source. Entities are the atomic units of data that get indexed and returned in search results. -Examples of entities: +Examples of entities include: - A Slack message or thread - A Notion page or database row -- A GitHub issue, pull request, or code file +- A GitHub issue, pull request (PR), or code file - A Google Doc or spreadsheet - A Zendesk ticket or customer conversation - An Airtable record or row @@ -78,4 +77,4 @@ When an agent searches a collection, the query runs across all entities from all Collections can be queried via the REST API, SDKs, or MCP, making them accessible to any AI agent, RAG pipeline, or application that needs grounded, up-to-date context. 
-To learn more about querying collections, including the three search tiers (instant, classic, agentic), retrieval strategies, and filtering, see the [Search](/search) documentation. +To learn more about querying collections, including the three search tiers (instant, classic, and agentic), retrieval strategies, and filtering, see the [Search](/search) documentation. diff --git a/fern/docs/pages/connectors/overview.mdx b/fern/docs/pages/connectors/overview.mdx index f937a05bb..cbfde1337 100644 --- a/fern/docs/pages/connectors/overview.mdx +++ b/fern/docs/pages/connectors/overview.mdx @@ -18,14 +18,14 @@ Airweave supports many different types of connectors across productivity tools, ### Popular Connectors - - - - - - - - + + + + + + + + ### All Connectors @@ -33,57 +33,57 @@ Airweave supports many different types of connectors across productivity tools, ### Productivity & Collaboration -- [Notion](/connectors/notion) -- [Slack](/connectors/slack) -- [Asana](/connectors/asana) -- [Monday](/connectors/monday) -- [Linear](/connectors/linear) -- [Trello](/connectors/trello) -- [Clickup](/connectors/clickup) -- [Todoist](/connectors/todoist) -- [Airtable](/connectors/airtable) +- [Notion](/docs/connectors/notion) +- [Slack](/docs/connectors/slack) +- [Asana](/docs/connectors/asana) +- [Monday](/docs/connectors/monday) +- [Linear](/docs/connectors/linear) +- [Trello](/docs/connectors/trello) +- [Clickup](/docs/connectors/clickup) +- [Todoist](/docs/connectors/todoist) +- [Airtable](/docs/connectors/airtable) ### Cloud Storage & Documents -- [Google Drive](/connectors/google_drive) -- [Google Docs](/connectors/google_docs) -- [Google Slides](/connectors/google_slides) -- [Dropbox](/connectors/dropbox) -- [OneDrive](/connectors/onedrive) -- [Box](/connectors/box) -- [SharePoint](/connectors/sharepoint) -- [Word](/connectors/word) -- [OneNote](/connectors/onenote) +- [Google Drive](/docs/connectors/google_drive) +- [Google Docs](/docs/connectors/google_docs) +- [Google 
Slides](/docs/connectors/google_slides) +- [Dropbox](/docs/connectors/dropbox) +- [OneDrive](/docs/connectors/onedrive) +- [Box](/docs/connectors/box) +- [SharePoint](/docs/connectors/sharepoint) +- [Word](/docs/connectors/word) +- [OneNote](/docs/connectors/onenote) ### Developer Tools -- [GitHub](/connectors/github) -- [GitLab](/connectors/gitlab) -- [Bitbucket](/connectors/bitbucket) -- [Jira](/connectors/jira) -- [Confluence](/connectors/confluence) +- [GitHub](/docs/connectors/github) +- [GitLab](/docs/connectors/gitlab) +- [Bitbucket](/docs/connectors/bitbucket) +- [Jira](/docs/connectors/jira) +- [Confluence](/docs/connectors/confluence) ### CRM & Sales -- [Salesforce](/connectors/salesforce) -- [HubSpot](/connectors/hubspot) -- [Pipedrive](/connectors/pipedrive) -- [Attio](/connectors/attio) -- [Zoho CRM](/connectors/zoho_crm) +- [Salesforce](/docs/connectors/salesforce) +- [HubSpot](/docs/connectors/hubspot) +- [Pipedrive](/docs/connectors/pipedrive) +- [Attio](/docs/connectors/attio) +- [Zoho CRM](/docs/connectors/zoho_crm) ### Communication & Email -- [Gmail](/connectors/gmail) -- [Outlook Mail](/connectors/outlook_mail) -- [Outlook Calendar](/connectors/outlook_calendar) -- [Google Calendar](/connectors/google_calendar) -- [Teams](/connectors/teams) +- [Gmail](/docs/connectors/gmail) +- [Outlook Mail](/docs/connectors/outlook_mail) +- [Outlook Calendar](/docs/connectors/outlook_calendar) +- [Google Calendar](/docs/connectors/google_calendar) +- [Teams](/docs/connectors/teams) ### Support & Service -- [Zendesk](/connectors/zendesk) +- [Zendesk](/docs/connectors/zendesk) ### E-commerce & Payments -- [Shopify](/connectors/shopify) -- [Stripe](/connectors/stripe) +- [Shopify](/docs/connectors/shopify) +- [Stripe](/docs/connectors/stripe) ### Other -- [Ctti](/connectors/ctti) +- [Ctti](/docs/connectors/ctti) diff --git a/fern/docs/pages/quickstart.mdx b/fern/docs/pages/quickstart.mdx index 565a53f48..2ca124b45 100644 --- a/fern/docs/pages/quickstart.mdx +++ 
b/fern/docs/pages/quickstart.mdx @@ -12,7 +12,7 @@ Follow this guide to get up and running with Airweave in just a few steps. The simplest way to use Airweave is through our hosted cloud platform at [app.airweave.ai](https://app.airweave.ai). - If you prefer to run Airweave yourself, you can deploy it locally on macOS, Linux or WSL. After cloning the repository and starting the server, you will be able to open the dashboard at http://localhost:8080 + If you prefer to run Airweave yourself, you can deploy it locally on macOS, Linux or WSL. After cloning the repository and starting the server, open the dashboard at [http://localhost:8080](http://localhost:8080). ```bash git clone https://github.com/airweave-ai/airweave.git @@ -21,7 +21,7 @@ Follow this guide to get up and running with Airweave in just a few steps. ``` - + Airweave provides SDKs for Python and Node.js. Install the package. @@ -47,10 +47,10 @@ Follow this guide to get up and running with Airweave in just a few steps. > Your browser does not support the video tag. - Initialize the Airweave client with your new API key. For local deployments, set base_url to `"http://localhost:8001"`. + Initialize the Airweave client with your new API key. For local deployments, set `base_url` to `"http://localhost:8001"`. - ```Python title="Python" + ```python title="Python" from airweave import AirweaveSDK airweave = AirweaveSDK(api_key="YOUR_API_KEY", base_url="https://api.airweave.ai") @@ -69,7 +69,7 @@ Follow this guide to get up and running with Airweave in just a few steps. A collection is a group of different data sources that you can search using a single endpoint. - ```Python title="Python" + ```python title="Python" collection = airweave.collections.create(name="My First Collection") print(f"Created collection: {collection.readable_id}") @@ -92,11 +92,11 @@ Follow this guide to get up and running with Airweave in just a few steps. - - A source connection links a specific app or database to your collection. 
It handles authentication and automatically syncs data. + + Source connections link specific apps or databases to your collection. They handle authentication and automatically sync data. You need at least one source connection per collection. -```Python title="Python" +```python title="Python" source_connection = airweave.source_connections.create( name="My Stripe Connection", short_name="stripe", @@ -149,10 +149,10 @@ curl -X POST 'https://api.airweave.ai/source-connections' \ You can now search your collection and get the most relevant results from all connected sources. -```Python title="Python" +```python title="Python" response = airweave.collections.search.instant( readable_id=collection.readable_id, - query="Find returned payments from user John Doe?", + query="Find returned payments from user John Doe", ) for result in response.results: @@ -162,7 +162,7 @@ for result in response.results: ```javascript title="Node.js" const response = await airweave.collections.search.instant( collection.readableId, - { query: "Find returned payments from user John Doe?" 
} + { query: "Find returned payments from user John Doe" } ); response.results.forEach(result => console.log(result.name, result.relevanceScore)); @@ -172,7 +172,7 @@ response.results.forEach(result => console.log(result.name, result.relevanceScor curl -X POST 'https://api.airweave.ai/collections/my-first-collection-abc123/search/instant' \ -H 'x-api-key: YOUR_API_KEY' \ -H 'Content-Type: application/json' \ - -d '{"query": "Find returned payments from user John Doe?"}' + -d '{"query": "Find returned payments from user John Doe"}' ``` diff --git a/fern/docs/pages/search.mdx b/fern/docs/pages/search.mdx index 114f7736a..2d6615b20 100644 --- a/fern/docs/pages/search.mdx +++ b/fern/docs/pages/search.mdx @@ -5,46 +5,46 @@ edit-this-page-url: https://github.com/airweave-ai/airweave/blob/main/fern/docs/ slug: search --- -## Instant Search +## Instant search `POST /collections/{id}/search/instant` -Direct vector search. Use when speed is critical (~0.5sec). +Use this direct vector search when speed is critical (~0.5sec). -The only parameter unique to instant is `retrieval_strategy`, which controls how the vector database matches your query: +The only parameter unique to instant search is `retrieval_strategy`, which controls how the vector database matches your query: -- **`hybrid`** (default) — Combines semantic and keyword search via Reciprocal Rank Fusion. Best for most queries. -- **`semantic`** — Dense vector cosine similarity. Finds conceptually similar content even when wording differs. -- **`keyword`** — BM25 text matching. Only returns content with your exact terms. Use for error codes, identifiers, or known phrases. +- **`hybrid`** (default) combines semantic and keyword search via Reciprocal Rank Fusion. It's best for most queries. +- **`semantic`** uses dense vector cosine similarity. It finds conceptually similar content even when wording differs. +- **`keyword`** uses BM25 text matching and only returns content with your exact terms. 
Use it for error codes, identifiers, or known phrases. In classic and agentic search, the retrieval strategy is chosen automatically. -## Classic Search +## Classic search `POST /collections/{id}/search/classic` -AI-optimized search strategy. Sensible default for most use cases (~2sec). +Classic search uses an AI-optimized search strategy. It's a sensible default for most use cases (~2sec). An LLM analyzes your query and generates an optimized search strategy. -## Agentic Search +## Agentic search `POST /collections/{id}/search/agentic` -Agent that navigates through your collection to find the best results. Use when recall matters more than latency (<2min). +Agentic search uses an agent to navigate through your collection to find the best results. Use it when recall matters more than latency (<2min). An AI agent iteratively searches your data using tool calling. It searches with multiple strategies, reads full documents, navigates entity hierarchies (parent/child/sibling), and builds a comprehensive result set. -Two parameters unique to agentic: +Two parameters are unique to agentic search: -- **`thinking`** — Enables extended chain-of-thought reasoning before tool calls. Better search strategies, but slower and uses more tokens. Useful for complex or ambiguous queries. -- **`limit`** — Unlike instant/classic where the vector database always returns up to `limit` results, the agent collects results based on relevance. It may return fewer if it decides there aren't enough matches. Setting a limit caps the maximum — if the agent collects more, results are truncated. When `null` (default), there is no cap. +- **`thinking`** enables extended chain-of-thought reasoning before tool calls. It results in better search strategies, but is slower and uses more tokens. It's useful for complex or ambiguous queries. 
+- **`limit`** sets the maximum number of results. Unlike instant and classic search, where the vector database always returns up to `limit` results, the agent collects results based on relevance and may return fewer if it decides there aren't enough matches. If the agent collects more than `limit`, results are truncated. When `limit` is `null` (default), there is no cap on the number of results. ### Streaming `POST /collections/{id}/search/agentic/stream` -Real-time SSE events as the agent works. Events are delivered as `data: {json}\n\n` messages. The stream terminates after a `done` or `error` event. +Streaming lets you see real-time SSE events as the agent works. Events are delivered as `data: {json}\n\n` messages. The stream terminates after a `done` or `error` event. Emitted once when the search begins. @@ -64,7 +64,7 @@ Emitted once when the search begins. -Emitted once per iteration after the LLM responds. `thinking` contains extended reasoning (when enabled), `text` contains conversational output before tool calls. +Emitted once per iteration after the LLM responds. `thinking` contains extended reasoning (when enabled), and `text` contains conversational output before tool calls. ```json { @@ -82,7 +82,7 @@ Emitted once per iteration after the LLM responds. `thinking` contains extended -Emitted after each tool the agent calls. `diagnostics.arguments` has the full tool input, `diagnostics.stats` has the output. The stats shape depends on which tool was called: +Emitted after each tool the agent calls. `diagnostics.arguments` has the full tool input, and `diagnostics.stats` has the output. The stats shape depends on which tool was called: ```json @@ -432,9 +432,9 @@ Emitted when the search fails. Also terminates the stream. ## Filters -Filters constrain search results by metadata. They work across all three tiers. +Filters constrain search results by metadata. They work across all three tiers (instant, classic, and agentic search). 
-In classic and agentic search, the AI generates its own filters internally, your filters are **AND'd into every search** it performs, acting as constraints that cannot be bypassed. +In classic and agentic search, the AI generates its own filters internally. Your filters are **AND'd into every search** it performs, acting as constraints that cannot be bypassed. ### Structure @@ -457,7 +457,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -### Filterable Fields +### Filterable fields | Field | Type | Description | |-------|------|-------------| @@ -534,7 +534,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -**Combine groups with OR — Slack messages OR Notion pages:** +**Combine groups with OR (Slack messages OR Notion pages):** ```json { @@ -555,7 +555,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -**Navigate hierarchy — find all entities inside a parent:** +**Navigate hierarchy (find all entities inside a parent):** ```json { @@ -569,7 +569,7 @@ This allows expressions like: `(A AND B) OR (C AND D)` } ``` -### Validation Rules +### Validation rules - Date fields (`created_at`, `updated_at`) require ISO 8601 timestamps (e.g., `2025-01-15T00:00:00Z`) - Ordering operators (`greater_than`, `less_than`, etc.) only work on date and numeric fields @@ -578,6 +578,63 @@ This allows expressions like: `(A AND B) OR (C AND D)` - Scalar operators (`equals`, `contains`, etc.) require a single value, not an array -## Response Format +## Response format -All three tiers return the same `SearchV2Response` with a `results` array. See the [API Reference](/api-reference/collections/instant-search) for the full response schema and interactive examples. +All three search tiers return the same `SearchV2Response` with a `results` array. See the [API Reference](/api-reference/collections/instant-search) for the full response schema and interactive examples. 
+ + +## Configuring the LLM provider chain + + + This section is only relevant to self-hosted deployments. The managed service ships with providers configured. + + +Classic and Agentic search call an LLM. Instant search does not — a backend with no LLM configured still answers instant queries, and Classic/Agentic return HTTP 503 until an API key is set. + +### Default chain + +Out of the box, Airweave tries providers in this order: + +1. `together:zai-glm-5` +2. `anthropic:claude-sonnet-4.6` + +The first provider with an API key set that responds successfully handles the request. Subsequent entries are tried only on failure. + +### Setting API keys + +Set at least one of the following environment variables on the backend: + +| Env var | Provider | +|---|---| +| `TOGETHER_API_KEY` | Together | +| `ANTHROPIC_API_KEY` | Anthropic | +| `MISTRAL_API_KEY` | Mistral | +| `GROQ_API_KEY` | Groq | +| `CEREBRAS_API_KEY` | Cerebras | + +If none are set, the backend boots normally; Classic/Agentic search return `503 Service Unavailable` with a message listing these variables. + +### Overriding the chain + +Set `LLM_FALLBACK_CHAIN` to a comma-separated list of `provider:model` pairs. Example: + +``` +LLM_FALLBACK_CHAIN=cerebras:gpt-oss-120b,anthropic:claude-sonnet-4.6 +``` + +Supported providers: `cerebras`, `groq`, `anthropic`, `together`, `mistral`. The full list of models per provider lives in `backend/airweave/adapters/llm/registry.py`. + +The parser validates three things at startup: + +- Every provider is a known provider. +- Every model is a known model. +- Every `(provider, model)` combination exists in the registry (e.g. `together:mistral-large` is rejected because `mistral-large` is hosted on Mistral, not Together). + +Misconfiguration is caught at startup with an error that lists the accepted values. + +### Fallback semantics + +- Providers without an API key are silently skipped when the chain is built. +- Providers whose initialization raises are logged and skipped. 
+- If the resulting chain is empty, the backend wires a null LLM — instant search still works; Classic/Agentic return 503. +- When a call fails in a chained provider, the next one is tried; a circuit breaker temporarily removes providers that recently failed. diff --git a/fern/docs/pages/welcome.mdx b/fern/docs/pages/welcome.mdx index e9e04fc09..0fab0d7b5 100644 --- a/fern/docs/pages/welcome.mdx +++ b/fern/docs/pages/welcome.mdx @@ -42,9 +42,9 @@ Airweave continuously syncs information from connected sources and makes it avai ## Who it's for -Developers and teams building AI agents and other AI-powered applications that need reliable access to information across multiple tools and data sources. +Ideal for developers and teams building AI agents and AI-powered applications, Airweave provides reliable access to information across multiple tools and data sources. -If you're working on long-running AI agents, retrieval-augmented generation, or any context-heavy LLM application, Airweave provides the infrastructure to retrieve and manage context without maintaining bespoke integrations for every data source. +If you're working on long-running AI agents, retrieval-augmented generation (RAG), or any context-heavy LLM application, Airweave provides the infrastructure to retrieve and manage context without maintaining bespoke integrations for every data source. ## Use cases @@ -69,7 +69,7 @@ Common use cases include: title="Multi-Source Context Retrieval" icon="fa-solid fa-database" > - Retrieve and combine relevant context from structured and unstructured sources at query time. 
+ Retrieve and combine relevant context from structured and unstructured sources at query time @@ -97,6 +97,6 @@ In all cases, Airweave helps agents retrieve facts from the right source instead icon="fa-solid fa-home" href="https://airweave.ai" > - High-level overview and latest product updates + See a high-level overview and the latest product updates diff --git a/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx b/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx new file mode 100644 index 000000000..f71a7a1c9 --- /dev/null +++ b/frontend/src/components/auth-providers/AuthProviderConnectionsList.tsx @@ -0,0 +1,92 @@ +import React from "react"; +import { useTheme } from "@/lib/theme-provider"; +import { cn } from "@/lib/utils"; +import { Plus } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { getAuthProviderIconUrl } from "@/lib/utils/icons"; +import { AuthProvider, AuthProviderConnection, useAuthProvidersStore } from "@/lib/stores/authProviders"; +import { format } from "date-fns"; + +interface AuthProviderConnectionsListProps { + authProvider: AuthProvider; + onSelectConnection: (connection: AuthProviderConnection) => void; + onAddNew: () => void; + onCancel: () => void; +} + +export const AuthProviderConnectionsList: React.FC = ({ + authProvider, + onSelectConnection, + onAddNew, + onCancel, +}) => { + const { resolvedTheme } = useTheme(); + const isDark = resolvedTheme === "dark"; + const { authProviderConnections } = useAuthProvidersStore(); + + const connections = authProviderConnections.filter( + (conn) => conn.short_name === authProvider?.short_name + ); + + return ( +
+ {/* Header */} +
+
+ {`${authProvider?.name} +
+
+

{authProvider?.name}

+

+ {connections.length} connection{connections.length !== 1 ? "s" : ""} +

+
+
+ + {/* Connections list */} +
+ {connections.map((conn) => ( + + ))} +
+ + {/* Actions */} +
+ + +
+
+ ); +}; diff --git a/frontend/src/components/auth-providers/AuthProviderDetailView.tsx b/frontend/src/components/auth-providers/AuthProviderDetailView.tsx index 3e23df8b2..6090866ae 100644 --- a/frontend/src/components/auth-providers/AuthProviderDetailView.tsx +++ b/frontend/src/components/auth-providers/AuthProviderDetailView.tsx @@ -226,20 +226,8 @@ export const AuthProviderDetailView: React.FC = ({ const [isDeleting, setIsDeleting] = useState(false); const [isClosing, setIsClosing] = useState(false); - console.log('🔍 [AuthProviderDetailView] Component mounted with:', { - authProviderConnectionId, - authProviderName, - authProviderShortName, - viewData - }); - - // Log component lifecycle useEffect(() => { - console.log('🌟 [AuthProviderDetailView] useEffect mount check'); - return () => { - console.log('💥 [AuthProviderDetailView] Component unmounting'); - // Clear any stored errors when component unmounts clearStoredErrorDetails(); }; }, []); @@ -247,16 +235,12 @@ export const AuthProviderDetailView: React.FC = ({ // Fetch connection details useEffect(() => { if (!authProviderConnectionId || isClosing) { - if (!authProviderConnectionId) { - console.warn('⚠️ [AuthProviderDetailView] No authProviderConnectionId provided'); - } return; } let isMounted = true; const fetchConnectionDetails = async () => { - console.log('📡 [AuthProviderDetailView] Fetching connection details for:', authProviderConnectionId); setLoading(true); try { const response = await apiClient.get(`/auth-providers/connections/${authProviderConnectionId}`); diff --git a/frontend/src/components/auth-providers/AuthProviderDialog.tsx b/frontend/src/components/auth-providers/AuthProviderDialog.tsx index a8b4db3df..cf2263340 100644 --- a/frontend/src/components/auth-providers/AuthProviderDialog.tsx +++ b/frontend/src/components/auth-providers/AuthProviderDialog.tsx @@ -2,6 +2,7 @@ import React, { useState, useEffect } from "react"; import { Dialog, DialogContent } from "@/components/ui/dialog"; import { 
ConfigureAuthProviderView } from "./ConfigureAuthProviderView"; import { AuthProviderDetailView } from "./AuthProviderDetailView"; +import { AuthProviderConnectionsList } from "./AuthProviderConnectionsList"; import { EditAuthProviderView } from "@/components/shared/views/EditAuthProviderView"; import { useTheme } from "@/lib/theme-provider"; import { cn } from "@/lib/utils"; @@ -14,7 +15,7 @@ export type { DialogViewProps }; interface AuthProviderDialogProps { open: boolean; onOpenChange: (open: boolean) => void; - mode: 'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit'; + mode: 'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit' | 'auth-provider-list'; authProvider: any; connection?: any; onComplete?: (result: any) => void; @@ -78,8 +79,6 @@ export const AuthProviderDialog: React.FC = ({ }; const handleNext = (data?: any) => { - console.log("🚀 [AuthProviderDialog] handleNext called with:", data); - // Merge new data with existing viewData const newViewData = { ...viewData, ...data }; setViewData(newViewData); @@ -91,8 +90,6 @@ export const AuthProviderDialog: React.FC = ({ }; const handleComplete = (result?: any) => { - console.log("✅ [AuthProviderDialog] handleComplete called with:", result); - // Handle different completion actions if (result?.action === 'edit') { // Switch to edit mode @@ -164,6 +161,25 @@ export const AuthProviderDialog: React.FC = ({ /> ); + case 'auth-provider-list': + return ( + { + setViewData(prev => ({ + ...prev, + authProviderConnectionId: conn.readable_id, + authProviderConnectionName: conn.name, + })); + setCurrentView('auth-provider-detail'); + }} + onAddNew={() => { + setCurrentView('auth-provider'); + }} + onCancel={handleCancel} + /> + ); + default: return null; } diff --git a/frontend/src/components/auth-providers/AuthProviderTable.tsx b/frontend/src/components/auth-providers/AuthProviderTable.tsx index 611706a4f..d38b9967f 100644 --- a/frontend/src/components/auth-providers/AuthProviderTable.tsx +++ 
b/frontend/src/components/auth-providers/AuthProviderTable.tsx @@ -24,7 +24,7 @@ export const AuthProviderTable = () => { const [dialogOpen, setDialogOpen] = useState(false); const [selectedAuthProvider, setSelectedAuthProvider] = useState(null); const [selectedConnection, setSelectedConnection] = useState(null); - const [dialogMode, setDialogMode] = useState<'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit'>('auth-provider'); + const [dialogMode, setDialogMode] = useState<'auth-provider' | 'auth-provider-detail' | 'auth-provider-edit' | 'auth-provider-list'>('auth-provider'); const [remountKey, setRemountKey] = useState(0); // Fetch auth providers and connections on component mount @@ -33,79 +33,35 @@ export const AuthProviderTable = () => { Promise.all([ fetchAuthProviders(), fetchAuthProviderConnections() - ]).then(([providers, connections]) => { - console.log(`🔄 [AuthProviderTable] Auth providers loaded: ${providers.length} providers, ${connections.length} connections`); - }); + ]); }, [fetchAuthProviders, fetchAuthProviderConnections]); - // Log state changes - useEffect(() => { - console.log('🎮 [AuthProviderTable] State changed:', { - dialogOpen, - dialogMode, - selectedAuthProvider: selectedAuthProvider?.short_name, - selectedConnection: selectedConnection?.readable_id - }); - }, [dialogOpen, dialogMode, selectedAuthProvider, selectedConnection]); - - // Log when connections change - useEffect(() => { - console.log('📊 [AuthProviderTable] Auth provider connections changed:', { - count: authProviderConnections.length, - isLoadingConnections - }); - }, [authProviderConnections, isLoadingConnections]); - - // Log when dialog open state changes - useEffect(() => { - console.log('🚨 [AuthProviderTable] dialogOpen state changed to:', dialogOpen); - }, [dialogOpen]); + const handleAuthProviderClick = (authProvider: any) => { + const connections = authProviderConnections.filter(conn => conn.short_name === authProvider.short_name); - // Log dialog state 
for debugging + setSelectedAuthProvider(authProvider); - const handleAuthProviderClick = (authProvider: any) => { - console.log('🖱️ [AuthProviderTable] handleAuthProviderClick called:', { - authProvider: authProvider.short_name, - hasConnection: !!authProviderConnections.find(conn => conn.short_name === authProvider.short_name) - }); - - const connection = authProviderConnections.find(conn => conn.short_name === authProvider.short_name); - - if (connection) { - console.log('🔗 [AuthProviderTable] Found existing connection:', connection.readable_id); - // Auth provider is already connected, show details - setSelectedAuthProvider(authProvider); - setSelectedConnection(connection); - setDialogMode('auth-provider-detail'); + if (connections.length === 0) { + if (!canManage) { + toast.info("Only admins can configure auth providers"); + return; + } + setSelectedConnection(null); + setDialogMode('auth-provider'); setDialogOpen(true); - } else if (!canManage) { - toast.info("Only admins can configure auth providers"); - return; } else { - console.log('➕ [AuthProviderTable] No connection found, opening configure dialog'); - // Auth provider not connected, show configure dialog - setSelectedAuthProvider(authProvider); setSelectedConnection(null); - setDialogMode('auth-provider'); + setDialogMode('auth-provider-list'); setDialogOpen(true); } - - console.log('🎯 [AuthProviderTable] Dialog state after click:', { - dialogOpen: true, - dialogMode: connection ? 
'auth-provider-detail' : 'auth-provider', - selectedAuthProvider: authProvider.short_name - }); }; const handleDialogComplete = (result: any) => { - console.log("🏁 [AuthProviderTable] Dialog completed:", result); - // Close the dialog setDialogOpen(false); // If it was an edit action, open edit dialog if (result?.action === 'edit') { - console.log("✏️ [AuthProviderTable] Edit action requested, opening edit dialog"); // Store the auth provider details for edit dialog const tempAuthProvider = selectedAuthProvider; @@ -130,7 +86,6 @@ export const AuthProviderTable = () => { // If it was an updated action, open detail dialog with refreshed data if (result?.action === 'updated') { - console.log("✅ [AuthProviderTable] Auth provider connection was updated"); // Find the updated connection from the refreshed list const updatedConnection = authProviderConnections.find( @@ -153,7 +108,6 @@ export const AuthProviderTable = () => { // If it was a deletion, increment remountKey to force dialog remount if (result?.action === 'deleted') { - console.log("🗑️ [AuthProviderTable] Auth provider connection was deleted"); setRemountKey(prev => prev + 1); } @@ -164,7 +118,6 @@ export const AuthProviderTable = () => { // Refresh connections if a new one was created or deleted if (result?.success) { - console.log("♻️ [AuthProviderTable] Refreshing auth provider connections"); fetchAuthProviderConnections(); } }; @@ -187,9 +140,7 @@ export const AuthProviderTable = () => { // Memoize dialog key to prevent remounts const dialogKey = useMemo(() => { // Only use auth provider short name as key since connection ID isn't available when creating new - const key = dialogOpen ? `auth-${selectedAuthProvider?.short_name || 'none'}-${remountKey}` : 'closed'; - console.log('🔑 [AuthProviderTable] Dialog key:', key); - return key; + return dialogOpen ? 
`auth-${selectedAuthProvider?.short_name || 'none'}-${remountKey}` : 'closed'; }, [dialogOpen, selectedAuthProvider?.short_name, remountKey]); return ( @@ -205,7 +156,7 @@ export const AuthProviderTable = () => { ) : ( allProviders.map(provider => { - const connection = authProviderConnections.find( + const connections = authProviderConnections.filter( conn => conn.short_name === provider.short_name ); @@ -215,7 +166,8 @@ export const AuthProviderTable = () => { id={provider.short_name} name={provider.name} shortName={provider.short_name} - isConnected={!!connection} + isConnected={connections.length > 0} + connectionCount={connections.length} isComingSoon={'isComingSoon' in provider ? provider.isComingSoon : false} onClick={() => handleAuthProviderClick(provider)} /> diff --git a/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx b/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx index be93ebb38..3c77b3541 100644 --- a/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx +++ b/frontend/src/components/auth-providers/ConfigureAuthProviderView.tsx @@ -81,19 +81,6 @@ export const ConfigureAuthProviderView: React.FC const navigate = useNavigate(); const { fetchAuthProviderConnections } = useAuthProvidersStore(); - // Log component lifecycle - useEffect(() => { - console.log('🌟 [ConfigureAuthProviderView] Component mounted:', { - authProviderName, - authProviderShortName, - viewData - }); - - return () => { - console.log('💥 [ConfigureAuthProviderView] Component unmounting'); - }; - }, []); - const [isSubmitting, setIsSubmitting] = useState(false); const [loading, setLoading] = useState(true); const [authProviderDetails, setAuthProviderDetails] = useState(null); @@ -105,11 +92,6 @@ export const ConfigureAuthProviderView: React.FC const [airweaveImageError, setAirweaveImageError] = useState(false); const [authProviderImageError, setAuthProviderImageError] = useState(false); - // Log loading state changes - useEffect(() 
=> { - console.log('⏳ [ConfigureAuthProviderView] Loading state:', loading); - }, [loading]); - // Default name for the connection const defaultConnectionName = authProviderName ? `My ${authProviderName} Connection` : "My Connection"; @@ -157,30 +139,18 @@ export const ConfigureAuthProviderView: React.FC // Fetch auth provider details useEffect(() => { - console.log('🔍 [ConfigureAuthProviderView] Auth provider details effect triggered:', { - authProviderShortName, - currentLoading: loading - }); - if (!authProviderShortName) { - console.log('⚠️ [ConfigureAuthProviderView] No authProviderShortName, skipping fetch'); setLoading(false); return; } const fetchDetails = async () => { - console.log('🚀 [ConfigureAuthProviderView] Starting to fetch auth provider details'); setLoading(true); try { const response = await apiClient.get(`/auth-providers/detail/${authProviderShortName}`); - console.log('📡 [ConfigureAuthProviderView] Auth provider details response:', response.ok); if (response.ok) { const data = await response.json(); - console.log('✅ [ConfigureAuthProviderView] Auth provider details loaded:', { - hasAuthFields: !!data.auth_fields, - fieldsCount: data.auth_fields?.fields?.length || 0 - }); setAuthProviderDetails(data); // Initialize auth field values @@ -195,16 +165,13 @@ export const ConfigureAuthProviderView: React.FC } } else { const errorText = await response.text(); - console.error('❌ [ConfigureAuthProviderView] Failed to load auth provider details:', errorText); throw new Error(`Failed to load auth provider details: ${errorText}`); } } catch (error) { - console.error("Error fetching auth provider details:", error); if (onError) { onError(error instanceof Error ? 
error : new Error(String(error)), authProviderName); } } finally { - console.log('🏁 [ConfigureAuthProviderView] Setting loading to false'); setLoading(false); } }; @@ -391,16 +358,7 @@ export const ConfigureAuthProviderView: React.FC duration: 5000, }); - // Navigate to detail view BEFORE refreshing connections - console.log('🎯 [ConfigureAuthProviderView] Connection created successfully:', { - connectionId: connection.id, - readableId: connection.readable_id, - name: connection.name, - shortName: connection.short_name - }); - if (onNext) { - console.log('🚀 [ConfigureAuthProviderView] Calling onNext to navigate to detail view'); onNext({ authProviderConnectionId: connection.readable_id, authProviderName: authProviderName, // Use the original auth provider name, not connection name @@ -408,17 +366,11 @@ export const ConfigureAuthProviderView: React.FC isNewConnection: true // Flag to indicate this is a new connection }); - // Refresh connections after navigation - testing without delay - console.log('📡 [ConfigureAuthProviderView] Refreshing auth provider connections after navigation'); fetchAuthProviderConnections(); } else { - console.warn('⚠️ [ConfigureAuthProviderView] onNext is not defined!'); - // If no onNext, refresh immediately await fetchAuthProviderConnections(); } } catch (error) { - console.error("Error creating auth provider connection:", error); - // Extract error message from the response let errorMessage = "Failed to create connection"; if (error instanceof Error) { @@ -654,6 +606,37 @@ export const ConfigureAuthProviderView: React.FC )} + {authProviderShortName === 'custom' && ( + + + + + + +
+

Your endpoint must implement:

+

GET {'{base_url}'} — return 2xx (used for validation)

+

GET {'{base_url}/{source_connection_id}'} — return JSON credentials

+

Response: {`{"access_token": "..."}`} or {`{"api_key": "..."}`}

+

Auth: X-API-Key header sent with every request

+

No refresh_token needed — Airweave re-fetches automatically

+
+
+
+
+ )}
{authProviderDetails.auth_fields.fields.map((field: any) => ( diff --git a/frontend/src/components/dashboard/AuthProviderButton.tsx b/frontend/src/components/dashboard/AuthProviderButton.tsx index a3434f2ec..5a006c54f 100644 --- a/frontend/src/components/dashboard/AuthProviderButton.tsx +++ b/frontend/src/components/dashboard/AuthProviderButton.tsx @@ -10,7 +10,8 @@ interface AuthProviderButtonProps { name: string; shortName: string; isConnected?: boolean; - isComingSoon?: boolean; // Add this prop + connectionCount?: number; + isComingSoon?: boolean; onClick?: () => void; } @@ -19,7 +20,8 @@ export const AuthProviderButton = ({ name, shortName, isConnected = false, - isComingSoon = false, // Add default value + connectionCount = 0, + isComingSoon = false, onClick }: AuthProviderButtonProps) => { const { resolvedTheme } = useTheme(); @@ -30,12 +32,6 @@ export const AuthProviderButton = ({ // Don't handle clicks for coming soon providers if (isComingSoon) return; - console.log('🔘 [AuthProviderButton] Button clicked:', { - id, - name, - shortName, - isConnected - }); if (onClick) { onClick(); } @@ -78,7 +74,19 @@ export const AuthProviderButton = ({ )}
- {name} +
+ {name} + {connectionCount > 1 && ( + + {connectionCount} + + )} +
{isComingSoon && ( Coming soon diff --git a/frontend/src/components/icons/auth_providers/custom-dark.svg b/frontend/src/components/icons/auth_providers/custom-dark.svg new file mode 100644 index 000000000..ecef112bf --- /dev/null +++ b/frontend/src/components/icons/auth_providers/custom-dark.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/frontend/src/components/icons/auth_providers/custom-light.svg b/frontend/src/components/icons/auth_providers/custom-light.svg new file mode 100644 index 000000000..822c72096 --- /dev/null +++ b/frontend/src/components/icons/auth_providers/custom-light.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/frontend/src/lib/constants/feature-flags.ts b/frontend/src/lib/constants/feature-flags.ts index 223868021..a1694c019 100644 --- a/frontend/src/lib/constants/feature-flags.ts +++ b/frontend/src/lib/constants/feature-flags.ts @@ -18,6 +18,9 @@ export const FeatureFlags = { // Connect CONNECT: 'connect', + + // Auth Providers + CUSTOM_AUTH_PROVIDER: 'custom_auth_provider', } as const; export type FeatureFlag = typeof FeatureFlags[keyof typeof FeatureFlags]; diff --git a/frontend/src/lib/stores/authProviders.ts b/frontend/src/lib/stores/authProviders.ts index d958e1549..ec8a57677 100644 --- a/frontend/src/lib/stores/authProviders.ts +++ b/frontend/src/lib/stores/authProviders.ts @@ -116,7 +116,6 @@ export const useAuthProvidersStore = create((set, get) => ({ }, clearAuthProviderConnections: () => { - console.log("🧹 [AuthProvidersStore] Clearing auth provider connections"); set({ authProviderConnections: [] }); } }));