Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,4 @@ examples/start_and_test.sh
# Old tutorials directory (superseded by examples/tutorial_*.py)
tutorials/
site/
.claude/
33 changes: 27 additions & 6 deletions src/locus/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,16 +1625,37 @@ async def _get_model_response(
# Pre-model hooks: allow hooks to modify messages before model call
messages = await self._run_before_model_hooks(messages, tool_schemas or None)

# When ``output_schema`` is set AND the provider ships native
# structured output (OpenAI's ``response_format`` shape), pass
# the JSON schema through directly. The provider parses + returns
# a typed response without the prompted-JSON fallback. Otherwise
# the schema only lives in the system prompt (see
# ``_create_initial_state``) and is parsed post-hoc.
native_response_format: dict[str, Any] | None = None
if self.config.output_schema is not None and getattr(
self._model, "supports_structured_output", False
):
from locus.core.structured import build_response_format

native_response_format = build_response_format(
self.config.output_schema,
strict=self.config.output_schema_strict,
)

# Call model with hook-driven retry support
# Hooks can request retries via event.retry = True
max_model_retries = 5
for _model_attempt in range(max_model_retries):
response = await self._model.complete(
messages=messages,
tools=tool_schemas or None,
temperature=self.config.temperature,
max_tokens=self.config.max_tokens,
)
complete_kwargs: dict[str, Any] = {
"messages": messages,
"tools": tool_schemas or None,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
}
if native_response_format is not None:
complete_kwargs["response_format"] = native_response_format

response = await self._model.complete(**complete_kwargs)

# Post-model hooks: event.retry = True to re-call
after_event = await self._run_after_model_hooks(response, messages)
Expand Down
9 changes: 9 additions & 0 deletions src/locus/models/native/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ class AnthropicModel(BaseModel):

model_config = {"arbitrary_types_allowed": True}

@property
def supports_structured_output(self) -> bool:
"""Anthropic doesn't ship OpenAI-style ``response_format``.

The agent loop falls back to the prompted-JSON path with
post-hoc parsing for Anthropic models.
"""
return False

def __init__(
self,
model: str = "claude-sonnet-4-20250514",
Expand Down
9 changes: 9 additions & 0 deletions src/locus/models/native/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ class OllamaModel(BaseModel):

model_config = {"arbitrary_types_allowed": True}

@property
def supports_structured_output(self) -> bool:
"""Ollama doesn't yet ship OpenAI-style ``response_format``.

The agent loop falls back to the prompted-JSON path with
post-hoc parsing for Ollama models.
"""
return False

def __init__(
self,
model: str = "llama3.3",
Expand Down
11 changes: 11 additions & 0 deletions src/locus/models/native/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ class OpenAIModel(BaseModel):

model_config = {"arbitrary_types_allowed": True}

@property
def supports_structured_output(self) -> bool:
"""Native ``response_format={"type":"json_schema",...}`` support.

OpenAI's chat-completions API accepts a JSON-schema response_format
and guarantees a parseable instance. The agent loop uses this
property to skip the prompted-JSON fallback when the provider
ships native structured output.
"""
return True

def __init__(
self,
model: str = "gpt-4o",
Expand Down
7 changes: 7 additions & 0 deletions src/locus/models/providers/oci/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class OCIModel(BaseModel):

model_config = {"arbitrary_types_allowed": True}

@property
def supports_structured_output(self) -> bool:
"""OCI's native SDK transport (Cohere R-series) doesn't expose
OpenAI-style ``response_format``. Use the V1 transport
(``OCIOpenAIModel``) for that."""
return False

def __init__(
self,
model_id: str = "cohere.command-r-plus",
Expand Down
113 changes: 113 additions & 0 deletions tests/unit/test_native_structured_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""Unit tests for native ``response_format`` pass-through on the agent loop.

When ``Agent(output_schema=Pydantic)`` is configured AND the provider
exposes ``supports_structured_output`` as True, the loop should pass
``response_format=`` to ``model.complete()`` directly — skipping the
prompted-JSON fallback.

When the provider returns False (Anthropic, Ollama, OCI's native SDK),
the loop falls back to the prompted-JSON path and ``response_format``
is NOT passed.
"""

from __future__ import annotations

from typing import Any
from unittest.mock import AsyncMock

import pytest
from pydantic import BaseModel

from locus.core.messages import Message


pytest.importorskip("openai")
pytest.importorskip("anthropic")


class SamplePayload(BaseModel):
name: str
score: float


class _StubModel:
"""Minimal model stub that records the kwargs it received."""

def __init__(self, supports: bool) -> None:
self.supports_structured_output = supports
self.complete = AsyncMock()
self.stream = AsyncMock()
self._captured_kwargs: dict[str, Any] = {}

async def complete(self, **kwargs: Any) -> Any: # type: ignore[override,no-redef]
self._captured_kwargs = kwargs
from locus.models.base import ModelResponse

return ModelResponse(
message=Message.assistant('{"name": "ok", "score": 0.9}'),
usage={"input_tokens": 10, "output_tokens": 4},
)


def test_supports_structured_output_capability_on_openai_model():
"""OpenAIModel reports True; structured output passes through natively."""
from locus.models.native.openai import OpenAIModel

model = OpenAIModel(model="gpt-4o", api_key="sk-test")
assert model.supports_structured_output is True


def test_supports_structured_output_capability_on_anthropic_model():
"""AnthropicModel reports False; falls back to prompted JSON."""
from locus.models.native.anthropic import AnthropicModel

model = AnthropicModel(model="claude-sonnet-4-20250514", api_key="sk-test")
assert model.supports_structured_output is False


def test_supports_structured_output_capability_on_ollama_model():
"""OllamaModel reports False."""
from locus.models.native.ollama import OllamaModel

model = OllamaModel(model="llama3.3")
assert model.supports_structured_output is False


def test_supports_structured_output_capability_on_oci_native_model():
"""OCIModel (native SDK transport) reports False; use OCIOpenAIModel for native."""
pytest.importorskip("oci")
from locus.models.providers.oci import OCIModel

# Model id chosen to route to the native SDK transport.
try:
model = OCIModel(model_id="cohere.command-r-08-2024", profile_name="DEFAULT")
except Exception:
pytest.skip("OCI client construction requires real config")
assert model.supports_structured_output is False


def test_oci_openai_compat_inherits_capability():
"""OCIOpenAIModel inherits from OpenAIModel; reports True."""
pytest.importorskip("oci")
from locus.models.providers.oci.openai_compat import OCIOpenAIModel

try:
model = OCIOpenAIModel(model="openai.gpt-5", profile_name="DEFAULT")
except Exception:
pytest.skip("OCIOpenAIModel construction requires OCI config")
assert model.supports_structured_output is True


def test_build_response_format_returns_openai_shape():
"""``build_response_format`` already returns the right shape — sanity check."""
from locus.core.structured import build_response_format

rf = build_response_format(SamplePayload, strict=True)
assert rf["type"] == "json_schema"
assert rf["json_schema"]["name"] == "SamplePayload"
assert rf["json_schema"]["strict"] is True
assert "schema" in rf["json_schema"]
# required fields propagated:
schema = rf["json_schema"]["schema"]
assert "name" in schema.get("required", [])
assert "score" in schema.get("required", [])
Loading