diff --git a/docs/how-to/oci-dac.md b/docs/how-to/oci-dac.md new file mode 100644 index 00000000..60103570 --- /dev/null +++ b/docs/how-to/oci-dac.md @@ -0,0 +1,117 @@ +# OCI Dedicated AI Cluster (DAC) endpoints + +OCI GenAI exposes two serving modes: + +- **On-demand** — pay-per-token against a shared model id (`openai.gpt-5.5`, + `cohere.command-r-plus-08-2024`, …). What `Agent(model="oci:openai.gpt-5.5")` + has been using by default. +- **Dedicated AI Cluster (DAC)** — provisioned capacity exposed as a + *generative AI endpoint* OCID + (`ocid1.generativeaiendpoint.oc1.....`). Inference is routed + to your cluster, with predictable latency and isolation guarantees. + +Locus auto-routes DAC endpoint OCIDs to the SDK transport (`OCIModel`) +because the V1 OpenAI-compatible endpoint doesn't speak +`DedicatedServingMode`. Pass the endpoint OCID exactly the way you'd +pass a model id: + +```python +from locus import Agent + +agent = Agent( + model="oci:ocid1.generativeaiendpoint.oc1.....", + compartment_id="ocid1.compartment.oc1...", # required for DAC + profile_name="DEFAULT", # any profile in ~/.oci/config + system_prompt="...", +) +``` + +Behind the scenes: + +```text +get_model("oci:ocid1.generativeaiendpoint....") + → OCIModel(model_id="ocid1.generativeaiendpoint....") + → OCIClient.get_serving_mode(...) + returns DedicatedServingMode(endpoint_id=...) + → SDK chat() routes to your DAC. +``` + +## Streaming + +`OCIModel.stream()` flips `is_stream=True` on the underlying +`GenericChatRequest` / `CohereChatRequest` and iterates the SSE event +stream the SDK returns. Works for both on-demand and DAC serving +modes, and for both Generic (Llama / OpenAI / xAI / Mistral / Gemini +on OCI) and Cohere R-series request shapes: + +```python +async for event in agent.run("Plan Q3"): + if isinstance(event, ModelChunkEvent) and event.content: + print(event.content, end="", flush=True) +``` + +Each SSE event carries a JSON delta. The provider's +`parse_stream_chunk()` extracts text deltas and tool-call deltas; if +the endpoint rejects `is_stream` (some custom DAC deployments do), +the stream falls back to the non-streaming `complete()` path and +yields a single chunk with the full content — so a misconfigured +endpoint never hard-fails the stream. + +## Auth + +DAC endpoints accept the same auth methods as on-demand: + +| Method | When | Kwarg | +|---|---|---| +| API key | local dev with `~/.oci/config` | `profile_name="DEFAULT"` | +| Session token | corporate SSO | `auth_type="security_token", profile_name="..."` | +| Instance principal | OCI VM / OKE / Functions | `auth_type="instance_principal"` | +| Resource principal | OCI Functions, Data Science | `auth_type="resource_principal"` | + +`compartment_id` is **required** for DAC — the dedicated endpoint +exists in a specific compartment, and the SDK validates the +`compartment_id` field on every chat request. + +## Tutorial-style env-var workflow + +`examples/config.py`'s `_pick_oci_transport()` recognises DAC OCIDs +and routes them to the SDK transport automatically: + +```bash +export LOCUS_MODEL_PROVIDER=oci +export LOCUS_MODEL_ID="ocid1.generativeaiendpoint.oc1....." +export LOCUS_OCI_PROFILE=MY_PROFILE +export LOCUS_OCI_COMPARTMENT="ocid1.compartment.oc1..." +python examples/tutorial_01_basic_agent.py +``` + +`LOCUS_OCI_TRANSPORT=sdk` forces the SDK transport explicitly if you +have a hosted model that uses an OCID-shaped name but isn't a real +DAC endpoint. + +## Things that go wrong + +| Symptom | Likely cause | +|---|---| +| `404 Not Found` on chat | Endpoint OCID is from a different region than the SDK is talking to. Pass the right `service_endpoint=` (or set `LOCUS_OCI_REGION`) to match the endpoint's region. | +| `compartment_id is required` | Pass `compartment_id=` on `Agent()` — DAC enforces it even when on-demand wouldn't. | +| Stream yields one big chunk instead of deltas | The endpoint rejected `is_stream`. The fall-back path swallows the failure and emits the full response as one chunk; check `OCI_LOG_REQUESTS=1` to see the API error. | +| `You are not authorized to perform this request` | The principal you're authenticating with doesn't have the `inspect generative-ai-endpoints` policy in the endpoint's compartment. | + +## Where the wiring lives + +- [`src/locus/models/registry.py`](https://github.com/oracle-samples/locus/blob/main/src/locus/models/registry.py) + — DAC OCIDs are detected by `lowered.startswith("ocid1.generativeaiendpoint.")` + and routed to `OCIModel`. +- [`src/locus/models/providers/oci/client.py`](https://github.com/oracle-samples/locus/blob/main/src/locus/models/providers/oci/client.py) + — `OCIClient.get_serving_mode()` returns `DedicatedServingMode(endpoint_id=...)` + for OCID-shaped model ids. +- [`src/locus/models/providers/oci/__init__.py`](https://github.com/oracle-samples/locus/blob/main/src/locus/models/providers/oci/__init__.py) + — `OCIModel.stream()` does the real SSE iteration. +- [`tests/unit/test_oci_dac.py`](https://github.com/oracle-samples/locus/blob/main/tests/unit/test_oci_dac.py) + — 12 unit tests covering routing, serving-mode selection, and stream-chunk parsing. + +## Related + +- [OCI GenAI](../concepts/providers/oci.md) — overview of the V1 vs SDK transports. +- [OCI models how-to](oci-models.md) — full transport story for on-demand. diff --git a/examples/config.py b/examples/config.py index c8463b0f..b3de6a9f 100644 --- a/examples/config.py +++ b/examples/config.py @@ -178,7 +178,13 @@ def _pick_oci_transport(model_id: str) -> str: forced = os.environ.get("LOCUS_OCI_TRANSPORT") if forced in ("v1", "sdk"): return forced - return "sdk" if model_id.lower().startswith("cohere.command-r") else "v1" + lowered = model_id.lower() + # DAC endpoint OCIDs and Cohere R-series both need the SDK + # transport (DedicatedServingMode for the former, the proprietary + # Cohere chat shape for the latter). + if lowered.startswith(("ocid1.generativeaiendpoint.", "cohere.command-r")): + return "sdk" + return "v1" def _get_oci_model(**kwargs: Any) -> Any: diff --git a/mkdocs.yml b/mkdocs.yml index 7a2864d4..bdb34657 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -164,6 +164,7 @@ nav: - Build a custom tool: how-to/custom-tools.md - Add a checkpointer backend: how-to/custom-checkpointer.md - OCI GenAI models: how-to/oci-models.md + - OCI Dedicated AI Cluster (DAC): how-to/oci-dac.md - API reference: - Agent: api/agent.md - Checkpointers: api/checkpointers.md diff --git a/src/locus/models/providers/oci/__init__.py b/src/locus/models/providers/oci/__init__.py index ad40f69b..2a20248a 100644 --- a/src/locus/models/providers/oci/__init__.py +++ b/src/locus/models/providers/oci/__init__.py @@ -244,24 +244,93 @@ async def stream( tools: list[dict[str, Any]] | None = None, **kwargs: Any, ) -> AsyncIterator[ModelChunkEvent]: - """Stream a chat response. + """Stream a chat response via the OCI GenAI SDK. - Note: OCI GenAI streaming is limited. This implementation - falls back to non-streaming and yields the full response. + Sets ``is_stream=True`` on the chat request so the SDK returns + an SSE event stream. Each ``data:`` event carries a JSON chunk + with ``message.content`` deltas and (on the last event) + ``finishReason``. Works for both ``OnDemandServingMode`` + (model id) and ``DedicatedServingMode`` (DAC endpoint OCID). + + On any exception the stream falls back to the non-streaming + ``complete()`` path and yields a single chunk with the full + content — robust to providers that reject ``is_stream``. """ - # OCI GenAI has limited streaming support - # Fall back to complete and yield in chunks - response = await self.complete(messages, tools, **kwargs) - - if response.content: - # Yield content in chunks for better UX - chunk_size = 50 - content = response.content - for i in range(0, len(content), chunk_size): - yield ModelChunkEvent(content=content[i : i + chunk_size]) - - if response.tool_calls: - yield ModelChunkEvent(tool_calls=response.tool_calls) + import json as _json + + from oci.generative_ai_inference import models + + # Build the same request shape as ``complete()`` but with + # ``is_stream=True`` so the SDK returns a streaming response. + converted_messages = self.provider.convert_messages(messages, model_id=self.config.model_id) + converted_tools = self.provider.convert_tools(tools) + request_kwargs = { + "max_tokens": self.config.max_tokens, + "temperature": self.config.temperature, + "top_p": self.config.top_p, + "model_id": self.config.model_id, + "is_stream": True, + } + + chat_request = self.provider.build_request( + converted_messages, + converted_tools, + **request_kwargs, + ) + # Some provider request builders (Cohere) take messages + # under a different field — they may have ignored is_stream + # in build_request. Set it on the resulting object as a + # belt-and-braces step. + if hasattr(chat_request, "is_stream"): + chat_request.is_stream = True + + chat_details = models.ChatDetails( + compartment_id=self.client.compartment_id, + serving_mode=self.client.get_serving_mode(self.config.model_id), + chat_request=chat_request, + ) + + loop = asyncio.get_running_loop() + try: + response = await loop.run_in_executor(None, lambda: self.client.chat(chat_details)) + except Exception: # noqa: BLE001 — fall back on any provider error + # Some DAC endpoints / model versions reject is_stream. + # Hand the user a working stream by chunking the + # non-streaming response. + non_stream = await self.complete(messages, tools, **kwargs) + if non_stream.content: + yield ModelChunkEvent(content=non_stream.content) + if non_stream.tool_calls: + yield ModelChunkEvent(tool_calls=non_stream.tool_calls) + yield ModelChunkEvent(done=True) + return + + # ``response.data`` is the raw streaming body. Iterate the SSE + # event stream synchronously in a worker thread so the asyncio + # loop stays responsive — each event is a small JSON delta. + events_iter = response.data.events() + sentinel = object() + + def _next_event() -> Any: + return next(events_iter, sentinel) + + while True: + event = await loop.run_in_executor(None, _next_event) + if event is sentinel: + break + data = getattr(event, "data", None) + if not data: + continue + try: + chunk = _json.loads(data) + except (ValueError, TypeError): + # Skip malformed deltas — keep the stream alive. + continue + content_delta, tool_calls_delta, _is_done = self.provider.parse_stream_chunk(chunk) + if content_delta: + yield ModelChunkEvent(content=content_delta) + if tool_calls_delta: + yield ModelChunkEvent(tool_calls=tool_calls_delta) yield ModelChunkEvent(done=True) diff --git a/src/locus/models/registry.py b/src/locus/models/registry.py index e824b50d..dd56eb4e 100644 --- a/src/locus/models/registry.py +++ b/src/locus/models/registry.py @@ -104,22 +104,35 @@ def _register_defaults() -> None: pass # OCI GenAI — pick the right transport per model family. - # Cohere R-series (cohere.command-r-*) needs the OCI SDK's - # proprietary chat shape and is routed through OCIModel. - # Everything else (OpenAI / Meta / xAI / Mistral / Gemini and - # non-R Cohere) goes through OCIOpenAIModel against - # /openai/v1/chat/completions — real SSE streaming, day-0 model - # support, no Project OCID required. See - # docs/how-to/oci-models.md. + # + # Three transport rules, evaluated top-down: + # 1. Dedicated AI Cluster (DAC) endpoint OCIDs — strings starting + # with ``ocid1.generativeaiendpoint.`` — go through ``OCIModel`` + # (SDK transport). The DAC endpoint OCID is passed verbatim to + # ``DedicatedServingMode(endpoint_id=...)``; the V1 transport + # doesn't speak that mode. + # 2. Cohere R-series (``cohere.command-r-*``) needs the OCI SDK's + # proprietary chat shape — also ``OCIModel``. + # 3. Everything else (OpenAI / Meta / xAI / Mistral / Gemini and + # non-R Cohere on-demand) goes through ``OCIOpenAIModel`` + # against ``/openai/v1/chat/completions`` — real SSE streaming, + # day-0 model support, no Project OCID required. + # + # See docs/how-to/oci-models.md and docs/how-to/oci-dac.md. try: from locus.models.providers.oci import OCIModel, OCIOpenAIModel def _make_oci(m: str, **kw: Any) -> ModelProtocol: - if m.lower().startswith("cohere.command-r"): + lowered = m.lower() + # Rule 1: DAC endpoint OCID → SDK transport. + if lowered.startswith("ocid1.generativeaiendpoint."): + return OCIModel(model_id=m, **kw) + # Rule 2: Cohere R-series → SDK transport. + if lowered.startswith("cohere.command-r"): # SDK transport: defaults to profile_name="DEFAULT" + API_KEY, # so no env-var fallback needed for one-line ergonomics. return OCIModel(model_id=m, **kw) - # V1 transport: strictly requires profile= or auth_type=. + # Rule 3: V1 transport. Strictly requires profile= or auth_type=. # Fall back to OCI_PROFILE env var so `Agent(model="oci:...")` # works in one line. Explicit kwargs always win. if "profile" not in kw and "auth_type" not in kw: diff --git a/tests/integration/test_oci_dac_live.py b/tests/integration/test_oci_dac_live.py new file mode 100644 index 00000000..f6993050 --- /dev/null +++ b/tests/integration/test_oci_dac_live.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025, 2026 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v1.0 as shown at +# https://oss.oracle.com/licenses/upl/ + +"""Live integration tests for OCI Dedicated AI Cluster (DAC) endpoints. + +These tests fire real inference requests against a configured DAC +endpoint and skip cleanly when the credentials / OCID aren't +configured. Activation requires: + +- ``OCI_DAC_ENDPOINT_OCID`` — the ``ocid1.generativeaiendpoint....`` + OCID for the DAC. +- ``OCI_DAC_COMPARTMENT_ID`` — compartment OCID where the endpoint + lives. +- ``OCI_DAC_REGION`` — the region hosting the endpoint + (e.g. ``uk-london-1``). +- ``OCI_PROFILE`` — a profile in ``~/.oci/config`` with permission to + invoke the endpoint. +- ``oci`` SDK installed (``pip install -e ".[oci]"``). + +Example: + + export OCI_DAC_ENDPOINT_OCID="ocid1.generativeaiendpoint.oc1.uk-london-1...." + export OCI_DAC_COMPARTMENT_ID="ocid1.compartment.oc1...." + export OCI_DAC_REGION="uk-london-1" + export OCI_PROFILE="MY_DAC_PROFILE" + pytest tests/integration/test_oci_dac_live.py -v + +The OCIDs are intentionally read from env vars rather than checked +into the repo (CLAUDE.md privacy rule). Each test asserts the smallest +useful invariant for the layer under test, so the tests stay +informative regardless of which model is wired behind the DAC +(qwen, llama, etc.). +""" + +from __future__ import annotations + +import os + +import pytest + + +# Skip everything if the DAC endpoint isn't configured. +_DAC_OCID = os.environ.get("OCI_DAC_ENDPOINT_OCID") +_DAC_COMPARTMENT = os.environ.get("OCI_DAC_COMPARTMENT_ID") +_DAC_REGION = os.environ.get("OCI_DAC_REGION", "us-chicago-1") +_OCI_PROFILE = os.environ.get("OCI_PROFILE", "DEFAULT") + + +pytestmark = pytest.mark.skipif( + not (_DAC_OCID and _DAC_COMPARTMENT), + reason=( + "OCI DAC endpoint not configured. Set OCI_DAC_ENDPOINT_OCID + " + "OCI_DAC_COMPARTMENT_ID + OCI_DAC_REGION + OCI_PROFILE to run." + ), +) + + +@pytest.fixture +def dac_model() -> object: + """Build an OCIModel pointed at the configured DAC endpoint.""" + pytest.importorskip("oci") + from locus.models.providers.oci import OCIAuthType, OCIModel + + # The DAC endpoint OCID *is* the model_id for routing; OCIClient's + # get_serving_mode() recognises the prefix and returns + # DedicatedServingMode(endpoint_id=...). + return OCIModel( + model_id=_DAC_OCID, # type: ignore[arg-type] # filtered above + compartment_id=_DAC_COMPARTMENT, + profile_name=_OCI_PROFILE, + auth_type=OCIAuthType.API_KEY, + # The SDK derives the service endpoint from the region. + service_endpoint=(f"https://inference.generativeai.{_DAC_REGION}.oci.oraclecloud.com"), + max_tokens=128, + ) + + +@pytest.mark.asyncio +async def test_dac_complete_returns_content(dac_model: object) -> None: + """Non-streaming chat against the DAC returns a non-empty response.""" + from locus.core.messages import Message + + response = await dac_model.complete( # type: ignore[attr-defined] + messages=[Message.user("Reply with the single word 'OK'.")], + tools=None, + ) + assert response.message is not None + content = response.message.content or "" + assert content.strip(), ( + f"DAC complete() returned empty content. " + f"usage={response.usage}, stop_reason={response.stop_reason}" + ) + + +@pytest.mark.asyncio +async def test_dac_stream_yields_chunks(dac_model: object) -> None: + """Streaming chat against the DAC yields at least one content chunk + and a final done event. + + Robust to endpoints that reject ``is_stream`` — the OCIModel.stream() + fallback path emits a single content chunk with the full response + in that case, which still satisfies the assertions. + """ + from locus.core.events import ModelChunkEvent + from locus.core.messages import Message + + chunks: list[ModelChunkEvent] = [] + done = False + async for event in dac_model.stream( # type: ignore[attr-defined] + messages=[Message.user("Count from 1 to 3, one number per line.")], + tools=None, + ): + chunks.append(event) + if event.done: + done = True + + assert done, "stream never emitted a done=True event" + content_chunks = [c for c in chunks if c.content] + assert content_chunks, f"stream emitted no content chunks. total_events={len(chunks)}" + full_content = "".join(c.content or "" for c in content_chunks) + assert full_content.strip(), "stream content is empty after concat" + + +@pytest.mark.asyncio +async def test_dac_via_get_model_routes_to_oci_model(dac_model: object) -> None: + """``get_model("oci:")`` returns an ``OCIModel`` instance + rather than ``OCIOpenAIModel`` (V1 transport can't speak DAC).""" + from locus.models import get_model + from locus.models.providers.oci import OCIModel + + model = get_model( + f"oci:{_DAC_OCID}", + compartment_id=_DAC_COMPARTMENT, + profile_name=_OCI_PROFILE, + service_endpoint=(f"https://inference.generativeai.{_DAC_REGION}.oci.oraclecloud.com"), + ) + assert isinstance(model, OCIModel) diff --git a/tests/unit/test_oci_dac.py b/tests/unit/test_oci_dac.py new file mode 100644 index 00000000..2cdccc06 --- /dev/null +++ b/tests/unit/test_oci_dac.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025, 2026 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v1.0 as shown at +# https://oss.oracle.com/licenses/upl/ + +"""Unit tests for OCI Dedicated AI Cluster (DAC) endpoint support. + +Covers: + +- ``get_model("oci:ocid1.generativeaiendpoint.....")`` routes + to ``OCIModel`` (SDK transport), not ``OCIOpenAIModel`` (V1). +- ``OCIClient.get_serving_mode()`` returns ``DedicatedServingMode`` for + endpoint OCIDs and ``OnDemandServingMode`` for plain model ids. +- ``GenericProvider.parse_stream_chunk()`` correctly extracts text + + tool-call deltas from the SSE event format the SDK emits. +- ``CohereProvider.parse_stream_chunk()`` does the same for Cohere's + chunk shape. +- ``examples/config.py`` ``_pick_oci_transport()`` returns ``"sdk"`` + for DAC OCIDs. + +Tests skip cleanly when the ``oci`` SDK isn't installed — the +provider routing tests fall through to the V1 transport in that case +(which is itself testable without the SDK). +""" + +from __future__ import annotations + +import pytest + + +# Generic placeholder OCIDs — never use real tenancy / endpoint OCIDs +# in test fixtures (CLAUDE.md privacy rule). These match the OCI shape +# (``ocid1..oc1..``) but the id portion is +# obviously synthetic. +_FAKE_DAC_OCID = ( + "ocid1.generativeaiendpoint.oc1.uk-london-1." + "abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890ab" +) +_FAKE_COMPARTMENT = "ocid1.compartment.oc1..abcdef1234567890abcdef1234567890abcdef1234567890abcdef" + + +# --------------------------------------------------------------------------- +# Routing: get_model("oci:ocid...") +# --------------------------------------------------------------------------- + + +class TestModelRegistryRoutesDACToSDK: + def test_dac_endpoint_routes_to_oci_model(self) -> None: + pytest.importorskip("oci") + from locus.models import get_model + from locus.models.providers.oci import OCIModel + + # The registry needs an auth path to actually instantiate the + # SDK client. ``profile_name`` keeps the constructor pure + # (deferred client creation), so we don't have to mock OCI. + model = get_model( + f"oci:{_FAKE_DAC_OCID}", + compartment_id=_FAKE_COMPARTMENT, + profile_name="DEFAULT", + ) + assert isinstance(model, OCIModel) + assert model.config.model_id == _FAKE_DAC_OCID + + def test_cohere_r_still_routes_to_oci_model(self) -> None: + # Regression check: the new DAC rule shouldn't change Cohere + # R-series routing. + pytest.importorskip("oci") + from locus.models import get_model + from locus.models.providers.oci import OCIModel + + model = get_model( + "oci:cohere.command-r-plus-08-2024", + compartment_id=_FAKE_COMPARTMENT, + profile_name="DEFAULT", + ) + assert isinstance(model, OCIModel) + + def test_non_dac_non_cohere_still_routes_to_v1(self) -> None: + # gpt-style on-demand models continue to use the V1 transport. + pytest.importorskip("oci") + from locus.models import get_model + from locus.models.providers.oci import OCIOpenAIModel + + model = get_model("oci:openai.gpt-5.5", profile="DEFAULT") + assert isinstance(model, OCIOpenAIModel) + + +# --------------------------------------------------------------------------- +# Serving-mode selection (already existed but worth a focused test) +# --------------------------------------------------------------------------- + + +class TestServingModeForDACOCIDs: + def test_endpoint_ocid_yields_dedicated_serving_mode(self) -> None: + pytest.importorskip("oci") + from oci.generative_ai_inference import models as oci_models + + from locus.models.providers.oci.client import ( + OCIAuthType, + OCIClient, + OCIClientConfig, + ) + + # Don't instantiate the actual OCI client; use the helper directly. + cfg = OCIClientConfig( + auth_type=OCIAuthType.API_KEY, + profile_name="DEFAULT", + compartment_id=_FAKE_COMPARTMENT, + ) + # OCIClient.__init__ defers SDK client creation until first use, + # so this is safe without mocking. + client = OCIClient.__new__(OCIClient) + client.config = cfg + + mode = client.get_serving_mode(_FAKE_DAC_OCID) + assert isinstance(mode, oci_models.DedicatedServingMode) + assert mode.endpoint_id == _FAKE_DAC_OCID + + def test_plain_model_id_yields_on_demand_serving_mode(self) -> None: + pytest.importorskip("oci") + from oci.generative_ai_inference import models as oci_models + + from locus.models.providers.oci.client import ( + OCIAuthType, + OCIClient, + OCIClientConfig, + ) + + cfg = OCIClientConfig( + auth_type=OCIAuthType.API_KEY, + profile_name="DEFAULT", + compartment_id=_FAKE_COMPARTMENT, + ) + client = OCIClient.__new__(OCIClient) + client.config = cfg + + mode = client.get_serving_mode("cohere.command-r-plus") + assert isinstance(mode, oci_models.OnDemandServingMode) + assert mode.model_id == "cohere.command-r-plus" + + +# --------------------------------------------------------------------------- +# Streaming chunk parsers +# --------------------------------------------------------------------------- + + +class TestGenericProviderStreamChunks: + def test_parses_text_delta(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import GenericProvider + + chunk = { + "message": { + "content": [ + {"type": "TEXT", "text": "Hello, "}, + ] + }, + } + text, tool_calls, is_done = GenericProvider().parse_stream_chunk(chunk) + assert text == "Hello, " + assert tool_calls == [] + assert not is_done + + def test_parses_finish_reason(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import GenericProvider + + chunk = {"finishReason": "stop"} + text, tool_calls, is_done = GenericProvider().parse_stream_chunk(chunk) + assert text == "" + assert tool_calls == [] + assert is_done is True + + def test_parses_tool_call_delta(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import GenericProvider + + chunk = { + "message": {"toolCalls": [{"id": "tc-1", "name": "lookup", "arguments": '{"q":"x"}'}]}, + } + text, tool_calls, is_done = GenericProvider().parse_stream_chunk(chunk) + assert text == "" + assert len(tool_calls) == 1 + assert tool_calls[0].name == "lookup" + assert tool_calls[0].arguments == {"q": "x"} + assert tool_calls[0].id == "tc-1" + + def test_handles_malformed_tool_args_gracefully(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import GenericProvider + + chunk = { + "message": {"toolCalls": [{"name": "x", "arguments": "not-json"}]}, + } + _, tool_calls, _ = GenericProvider().parse_stream_chunk(chunk) + assert len(tool_calls) == 1 + assert tool_calls[0].arguments == {} # Falls back, doesn't raise + + +class TestCohereProviderStreamChunks: + def test_parses_text_delta(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import CohereProvider + + chunk = {"text": "world"} + text, tool_calls, is_done = CohereProvider().parse_stream_chunk(chunk) + assert text == "world" + assert tool_calls == [] + assert not is_done + + def test_parses_tool_call_on_final_event(self) -> None: + pytest.importorskip("oci") + from locus.models.providers.oci.models import CohereProvider + + chunk = { + "finishReason": "stop", + "toolCalls": [{"name": "search", "parameters": {"q": "tokyo"}}], + } + text, tool_calls, is_done = CohereProvider().parse_stream_chunk(chunk) + assert text == "" + assert is_done is True + assert len(tool_calls) == 1 + assert tool_calls[0].name == "search" + assert tool_calls[0].arguments == {"q": "tokyo"} + + +# --------------------------------------------------------------------------- +# examples/config.py transport routing +# --------------------------------------------------------------------------- + + +class TestExamplesConfigPicksTransport: + def test_dac_ocid_picks_sdk(self) -> None: + # The examples/config.py module isn't on the package path; load + # it as a script for this test. Doing it here instead of a + # conftest fixture keeps the dependency narrow. + import importlib.util + from pathlib import Path + + path = Path(__file__).resolve().parents[2] / "examples" / "config.py" + spec = importlib.util.spec_from_file_location("examples_config", path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + assert module._pick_oci_transport(_FAKE_DAC_OCID) == "sdk" + assert module._pick_oci_transport("cohere.command-r-plus-08-2024") == "sdk" + assert module._pick_oci_transport("openai.gpt-5.5") == "v1"