diff --git a/.gitignore b/.gitignore index 7c7ee5e..35f4d76 100644 --- a/.gitignore +++ b/.gitignore @@ -255,3 +255,4 @@ examples/start_and_test.sh # Old tutorials directory (superseded by examples/tutorial_*.py) tutorials/ site/ +.claude/ diff --git a/docs/concepts/providers/anthropic.md b/docs/concepts/providers/anthropic.md index ab9a543..b7a5af5 100644 --- a/docs/concepts/providers/anthropic.md +++ b/docs/concepts/providers/anthropic.md @@ -109,7 +109,7 @@ result = agent.run_sync("This page is broken!") print(result.parsed) # Triage(severity='high', needs_human=True) ``` -### Prompt caching — automatic for long prompts +### Prompt caching — opt in for long prompts This is the biggest cost saver if your system prompt or tool block is long (skills, playbooks, RAG context). Anthropic's prompt-caching @@ -117,24 +117,42 @@ mechanism marks a span of the request as cacheable; subsequent turns within the cache window pay **1/10th** the input cost on the cached span. -locus reads the request shape and applies `cache_control` to anything -beyond a small threshold automatically. You don't opt in. +Opt in with `prompt_cache=True` on `AnthropicModel`. Locus then sends +the system prompt as a block list with `cache_control: ephemeral` and +tags the last entry of the tool catalog the same way (a marker caches +the whole prefix up to and including the tagged block). 
```python -# Force or suppress caching explicitly: +from locus import Agent +from locus.models.native.anthropic import AnthropicModel + agent = Agent( - model="anthropic:claude-sonnet-4-20250514", - model_config={"prompt_cache": True}, # or False to opt out + model=AnthropicModel( + model="claude-sonnet-4-20250514", + prompt_cache=True, + ), + tools=[...], + system_prompt="", ) + +result = agent.run_sync("...") +print(f"cache writes: {result.metrics.cache_creation_input_tokens}") +print(f"cache reads: {result.metrics.cache_read_input_tokens}") +# → cache writes: 4092 (turn 1, written once) +# → cache reads: 4092 (turn 2 — same prefix, ~10× cheaper input) ``` When it kicks in: -- A 5-minute "ephemeral" cache (rolling window) — the default. +- A 5-minute "ephemeral" cache (rolling window). - Subsequent turns reusing the same prefix pay `0.1× input rate` on the cached portion. -- Effective when system prompts > ~1024 tokens, or you've loaded a - big skill / playbook / RAG block. +- Only applies once the cached prefix reaches the model's minimum cacheable length (1024 tokens on Claude Sonnet 4), e.g. after you've loaded + a big skill / playbook / RAG block. + +`cache_creation_input_tokens` and `cache_read_input_tokens` surface +on `AgentResult.metrics` so observability hooks can chart cache hits +and the cost saved. 
### Extended thinking — visible reasoning diff --git a/src/locus/agent/agent.py b/src/locus/agent/agent.py index 8224637..dc2ffb0 100644 --- a/src/locus/agent/agent.py +++ b/src/locus/agent/agent.py @@ -510,8 +510,15 @@ async def run( ) prompt_toks = response.usage.get("prompt_tokens", 0) completion_toks = response.usage.get("completion_tokens", 0) + cache_creation_toks = response.usage.get("cache_creation_input_tokens", 0) + cache_read_toks = response.usage.get("cache_read_input_tokens", 0) _total_tokens += prompt_toks + completion_toks - state = state.with_token_usage(prompt_toks, completion_toks) + state = state.with_token_usage( + prompt_toks, + completion_toks, + cache_creation_tokens=cache_creation_toks, + cache_read_tokens=cache_read_toks, + ) summary = ( response.message.content @@ -579,8 +586,15 @@ async def run( response, state = await self._get_model_response(state) prompt_toks = response.usage.get("prompt_tokens", 0) completion_toks = response.usage.get("completion_tokens", 0) + cache_creation_toks = response.usage.get("cache_creation_input_tokens", 0) + cache_read_toks = response.usage.get("cache_read_input_tokens", 0) _total_tokens += prompt_toks + completion_toks - state = state.with_token_usage(prompt_toks, completion_toks) + state = state.with_token_usage( + prompt_toks, + completion_toks, + cache_creation_tokens=cache_creation_toks, + cache_read_tokens=cache_read_toks, + ) _last_assistant_content = response.message.content # Track for the user-supplied termination condition. Updated again # below if a Cohere-style text tool call is parsed out of the body. 
@@ -1064,6 +1078,8 @@ async def _run() -> AgentResult: total_tokens=state.total_tokens_used, prompt_tokens=state.prompt_tokens_used, completion_tokens=state.completion_tokens_used, + cache_creation_input_tokens=state.cache_creation_tokens_used, + cache_read_input_tokens=state.cache_read_tokens_used, duration_ms=elapsed_ms, ) @@ -1322,8 +1338,15 @@ async def _run_from_state( response, state = await self._get_model_response(state) prompt_toks = response.usage.get("prompt_tokens", 0) completion_toks = response.usage.get("completion_tokens", 0) + cache_creation_toks = response.usage.get("cache_creation_input_tokens", 0) + cache_read_toks = response.usage.get("cache_read_input_tokens", 0) _total_tokens += prompt_toks + completion_toks - state = state.with_token_usage(prompt_toks, completion_toks) + state = state.with_token_usage( + prompt_toks, + completion_toks, + cache_creation_tokens=cache_creation_toks, + cache_read_tokens=cache_read_toks, + ) _last_assistant_content = response.message.content _last_no_tool_calls = not response.message.tool_calls diff --git a/src/locus/agent/result.py b/src/locus/agent/result.py index daff586..d94f18a 100644 --- a/src/locus/agent/result.py +++ b/src/locus/agent/result.py @@ -30,6 +30,12 @@ class ExecutionMetrics(BaseModel): duration_ms: float = 0.0 reflexion_evaluations: int = 0 grounding_evaluations: int = 0 + # Anthropic prompt-caching token counts. Populated only when the + # AnthropicModel is configured with `prompt_cache=True` and the + # provider returns cache_creation_input_tokens / cache_read_input_tokens + # on the response usage. Zero on other providers. 
+ cache_creation_input_tokens: int = 0 + cache_read_input_tokens: int = 0 model_config = {"frozen": True} diff --git a/src/locus/core/state.py b/src/locus/core/state.py index f55dd74..690f568 100644 --- a/src/locus/core/state.py +++ b/src/locus/core/state.py @@ -86,6 +86,11 @@ class AgentState(BaseModel): total_tokens_used: int = 0 prompt_tokens_used: int = 0 completion_tokens_used: int = 0 + # Anthropic prompt-cache token counts. Populated only when an + # AnthropicModel is configured with prompt_cache=True. Zero on + # other providers. + cache_creation_tokens_used: int = 0 + cache_read_tokens_used: int = 0 token_budget: int | None = None # Completion mode @@ -202,13 +207,29 @@ def with_metadata(self, key: str, value: Any) -> AgentState: } ) - def with_token_usage(self, prompt_tokens: int, completion_tokens: int) -> AgentState: - """Record token usage from a model response.""" + def with_token_usage( + self, + prompt_tokens: int, + completion_tokens: int, + cache_creation_tokens: int = 0, + cache_read_tokens: int = 0, + ) -> AgentState: + """Record token usage from a model response. + + ``cache_creation_tokens`` and ``cache_read_tokens`` are populated + only when Anthropic returns prompt-cache stats on the response + usage (i.e., the AnthropicModel was configured with + ``prompt_cache=True``). Default 0 for other providers. 
+ """ return self.model_copy( update={ "total_tokens_used": self.total_tokens_used + prompt_tokens + completion_tokens, "prompt_tokens_used": self.prompt_tokens_used + prompt_tokens, "completion_tokens_used": self.completion_tokens_used + completion_tokens, + "cache_creation_tokens_used": ( + self.cache_creation_tokens_used + cache_creation_tokens + ), + "cache_read_tokens_used": self.cache_read_tokens_used + cache_read_tokens, "updated_at": datetime.now(UTC), } ) diff --git a/src/locus/models/native/anthropic.py b/src/locus/models/native/anthropic.py index 1c2a322..10ce490 100644 --- a/src/locus/models/native/anthropic.py +++ b/src/locus/models/native/anthropic.py @@ -29,6 +29,15 @@ class AnthropicConfig(ModelConfig): top_p: float = 0.9 api_key: str | None = Field(default=None, description="Anthropic API key") base_url: str | None = Field(default=None, description="Custom API base URL") + prompt_cache: bool = Field( + default=False, + description=( + "When True, mark the system prompt and tool catalog with " + "Anthropic's `cache_control: ephemeral` so subsequent turns " + "reuse the cached input at ~1/10x cost. Default False for " + "backward compatibility." + ), + ) class AnthropicModel(BaseModel): @@ -53,6 +62,7 @@ def __init__( base_url: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, + prompt_cache: bool = False, **kwargs: Any, ) -> None: config = AnthropicConfig( @@ -61,6 +71,7 @@ def __init__( base_url=base_url, max_tokens=max_tokens, temperature=temperature, + prompt_cache=prompt_cache, **kwargs, ) super().__init__(config=config) @@ -195,7 +206,20 @@ async def complete( "temperature": kwargs.get("temperature", self.config.temperature), } if system_prompt: - params["system"] = system_prompt + # When prompt-caching is enabled, send the system prompt as a + # block list with ``cache_control: ephemeral`` so subsequent + # turns reuse the cached input at ~1/10x cost (Anthropic + # ephemeral cache TTL is ~5 min). 
+ if self.config.prompt_cache: + params["system"] = [ + { + "type": "text", + "text": system_prompt, + "cache_control": {"type": "ephemeral"}, + } + ] + else: + params["system"] = system_prompt # Structured-output mode: emulate ``response_format`` via tool-use. response_format = kwargs.get("response_format") @@ -211,6 +235,17 @@ async def complete( } if anthropic_tools: + # Cache the tool catalog too — it's typically the same across + # turns and can be large. Anthropic walks the cache_control + # markers in order; tagging the last tool covers the catalog. + if self.config.prompt_cache and anthropic_tools: + anthropic_tools = [ + *anthropic_tools[:-1], + { + **anthropic_tools[-1], + "cache_control": {"type": "ephemeral"}, + }, + ] params["tools"] = anthropic_tools response = await self.client.messages.create(**params) @@ -240,12 +275,21 @@ async def complete( if structured_mode and structured_payload is not None: content = _json.dumps(structured_payload) - usage = {} + usage: dict[str, int] = {} if response.usage: usage = { "prompt_tokens": response.usage.input_tokens, "completion_tokens": response.usage.output_tokens, } + # Anthropic returns these only when prompt caching is in play. + # Surface them on usage so AgentResult.metrics can show + # cache hits/misses and cost-saved estimates. + cache_creation = getattr(response.usage, "cache_creation_input_tokens", None) + cache_read = getattr(response.usage, "cache_read_input_tokens", None) + if cache_creation is not None: + usage["cache_creation_input_tokens"] = cache_creation + if cache_read is not None: + usage["cache_read_input_tokens"] = cache_read return ModelResponse( message=Message.assistant(content=content, tool_calls=tool_calls), diff --git a/tests/unit/test_anthropic_prompt_caching.py b/tests/unit/test_anthropic_prompt_caching.py new file mode 100644 index 0000000..5e94b42 --- /dev/null +++ b/tests/unit/test_anthropic_prompt_caching.py @@ -0,0 +1,200 @@ +"""Unit tests for Anthropic prompt-caching wiring. 
+ +Verifies that with ``prompt_cache=True`` on ``AnthropicModel``: +1. The system prompt is sent as a block list with ``cache_control: ephemeral``. +2. The last tool entry carries ``cache_control: ephemeral`` (caches the catalog). +3. Cache token counts on the response (``cache_creation_input_tokens`` / + ``cache_read_input_tokens``) flow into ``usage`` and propagate up to + ``ExecutionMetrics``. + +The Anthropic SDK is mocked so we can inspect the request params we pass +to ``client.messages.create()`` without a real API call. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock + +import pytest + + +pytest.importorskip("anthropic") + +from locus.core.messages import Message +from locus.models.native.anthropic import AnthropicModel + + +def _make_response_with_usage( + *, + input_tokens: int = 100, + output_tokens: int = 20, + cache_creation: int | None = None, + cache_read: int | None = None, +): + """Build a fake Anthropic response with the requested usage shape.""" + from types import SimpleNamespace + + usage = SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens) + if cache_creation is not None: + usage.cache_creation_input_tokens = cache_creation + if cache_read is not None: + usage.cache_read_input_tokens = cache_read + + text_block = SimpleNamespace(type="text", text="hi") + return SimpleNamespace( + content=[text_block], + usage=usage, + stop_reason="end_turn", + ) + + +def _build_model_with_mocked_client(*, prompt_cache: bool): + model = AnthropicModel( + model="claude-sonnet-4-20250514", + api_key="sk-test", + prompt_cache=prompt_cache, + ) + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock(return_value=_make_response_with_usage()) + model._client = mock_client + return model, mock_client + + +def test_system_prompt_uses_cache_control_when_prompt_cache_enabled(): + import asyncio + + model, mock_client = _build_model_with_mocked_client(prompt_cache=True) + + 
asyncio.run(model.complete([Message.system("You are helpful."), Message.user("hi")])) + + call = mock_client.messages.create.call_args + system_param = call.kwargs["system"] + assert isinstance(system_param, list), "system should be a block list when caching" + assert system_param[0]["type"] == "text" + assert system_param[0]["text"] == "You are helpful." + assert system_param[0]["cache_control"] == {"type": "ephemeral"} + + +def test_system_prompt_is_plain_string_when_caching_disabled(): + """Backward-compat: prompt_cache=False keeps system as a bare string.""" + import asyncio + + model, mock_client = _build_model_with_mocked_client(prompt_cache=False) + + asyncio.run(model.complete([Message.system("You are helpful."), Message.user("hi")])) + + call = mock_client.messages.create.call_args + assert call.kwargs["system"] == "You are helpful." + + +def test_tool_catalog_gets_cache_control_when_caching_enabled(): + import asyncio + + model, mock_client = _build_model_with_mocked_client(prompt_cache=True) + + tools = [ + { + "type": "function", + "function": { + "name": "search", + "description": "Search the web.", + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "summarise", + "description": "Summarise text.", + "parameters": {"type": "object", "properties": {}}, + }, + }, + ] + + asyncio.run(model.complete([Message.user("hi")], tools=tools)) + + call = mock_client.messages.create.call_args + anthropic_tools = call.kwargs["tools"] + assert len(anthropic_tools) == 2 + # Last tool carries cache_control; first does not. 
+ assert "cache_control" not in anthropic_tools[0] + assert anthropic_tools[1]["cache_control"] == {"type": "ephemeral"} + + +def test_cache_tokens_surfaced_on_usage_dict(): + """When the response includes cache stats, they flow into ModelResponse.usage.""" + import asyncio + + model = AnthropicModel(model="claude-sonnet-4-20250514", api_key="sk-test", prompt_cache=True) + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock( + return_value=_make_response_with_usage( + input_tokens=50, + output_tokens=10, + cache_creation=1000, + cache_read=4500, + ) + ) + model._client = mock_client + + response = asyncio.run(model.complete([Message.user("hi")])) + + assert response.usage["prompt_tokens"] == 50 + assert response.usage["completion_tokens"] == 10 + assert response.usage["cache_creation_input_tokens"] == 1000 + assert response.usage["cache_read_input_tokens"] == 4500 + + +def test_no_cache_fields_when_anthropic_omits_them(): + """Old SDK versions / non-cache responses don't include cache fields.""" + import asyncio + + model = AnthropicModel(model="claude-sonnet-4-20250514", api_key="sk-test", prompt_cache=False) + mock_client = AsyncMock() + mock_client.messages.create = AsyncMock( + return_value=_make_response_with_usage(input_tokens=50, output_tokens=10) + ) + model._client = mock_client + + response = asyncio.run(model.complete([Message.user("hi")])) + + assert "cache_creation_input_tokens" not in response.usage + assert "cache_read_input_tokens" not in response.usage + + +def test_execution_metrics_have_cache_fields(): + """ExecutionMetrics carries cache_creation_input_tokens / cache_read_input_tokens.""" + from locus.agent.result import ExecutionMetrics + + metrics = ExecutionMetrics( + iterations=2, + prompt_tokens=100, + completion_tokens=20, + cache_creation_input_tokens=500, + cache_read_input_tokens=2000, + ) + assert metrics.cache_creation_input_tokens == 500 + assert metrics.cache_read_input_tokens == 2000 + + +def 
test_state_with_token_usage_accepts_cache_counts(): + """AgentState.with_token_usage accepts and accumulates cache counts.""" + from locus.core.state import AgentState + + state = AgentState() + state = state.with_token_usage( + prompt_tokens=100, + completion_tokens=20, + cache_creation_tokens=500, + cache_read_tokens=2000, + ) + state = state.with_token_usage( + prompt_tokens=50, + completion_tokens=10, + cache_read_tokens=1500, + ) + + assert state.prompt_tokens_used == 150 + assert state.completion_tokens_used == 30 + assert state.cache_creation_tokens_used == 500 + assert state.cache_read_tokens_used == 3500