From 1f035947c0477f533ed645eb96fd6106bbb43ba0 Mon Sep 17 00:00:00 2001 From: Frank Speiser Date: Tue, 26 May 2026 11:03:40 -0400 Subject: [PATCH] Add per-turn cost cap for worker tool-use (soft default + 3 modes) Closes the PR #21 follow-up. The cap is per-worker-turn (matches the hop-budget scope) and defaults to soft-warn behavior so the calling agent can ask the user how to proceed when the cap is hit. Three modes: - warn (default when a cap is set): emit a `worker_tool_cost_warning` ndjson event the first time observed > cap, attach a `cost_cap` block to the answer, and KEEP RUNNING. The response also carries an `operator_prompt` string so the calling agent has a clear signal to surface an AskUserQuestion offering enforce / warn-only / ignore. - enforce: when observed > cap, the next inner tool_call is refused with a `cost_cap_exceeded` payload (wrapped as a tool_result so the worker sees a consistent envelope). Worker gets one final emission round to produce its answer with whatever it already has. - off: no checks at all; `cost_cap` block is suppressed entirely (the operator asked to ignore the cap, so the response stays clean). Scope decision: per-turn, not session-wide. Each top-level confer/coordinate call gets a fresh cap budget, matching how the hop budget already works. Implementation: - `_worker_tool_cost_cap_defaults(caller_cap_usd, caller_mode)` resolves per-call args against CFG.worker_tools.{cost_cap_usd,cost_cap_mode}. cap_usd <= 0 disables; unknown mode falls through to "warn". - `_worker_tool_cost_observed(aggregated)` pulls cumulative $-cost out of the merged answer's usage block. - `_worker_tool_cost_cap_refusal(observed, cap)` formats the enforce-mode refusal payload (wrapped as a `tool_result` so the worker sees the same envelope shape regardless of outcome). - `_ask_one_with_tools` grows `cost_cap_usd` + `cost_cap_mode` kwargs. Pre-dispatch check: if enforce + observed > cap, refuse + one final emission round. 
Post-call check (warn mode): emit the warning event the first time we observe > cap. - `_request_structured_with_tools` mirrors the same logic and uses a `_finalize()` helper to consistently attach the cost_cap block on every return path. - `_request_structured` + `_ask_many_parallel` forward the new kwargs through. - `tool_confer` + `tool_coordinate` accept `worker_tool_cost_cap_usd` + `worker_tool_cost_cap_mode` args. Both surface aggregated cost-cap state on the response under `worker_tools.cost_cap` (with `per_provider` / `per_role` breakdown + an `operator_prompt` when soft-warn was tripped). Schema additions: `worker_tool_cost_cap_usd` (number, minimum 0) and `worker_tool_cost_cap_mode` (enum) on both confer.input and coordinate.input; descriptions document the warn-default + the agent's AskUserQuestion follow-up responsibility. Tests (scripts/test_worker_tool_cost_cap.py): - `_worker_tool_cost_cap_defaults`: CFG fallthrough, per-call override, 0/negative disables, unknown mode falls through to warn - warn mode: both inner fetches execute, cost_cap block attached with exceeded=true / blocked=false - enforce mode: only first fetch executes; second gets cost_cap_exceeded; final emission round runs - off mode: no cost_cap block on the answer (caller asked to ignore) - No cap set: legacy worker_tools behavior preserved (no block) - Cap not exceeded: block still attached with exceeded=false - Structured variant: same behavior inside `_request_structured_with_tools` - End-to-end coordinate: cost_cap surfaces on `worker_tools.cost_cap` with per_role breakdown; synth role NOT in the per_role list (excluded from worker_tools by design); operator_prompt present when warn-mode tripped Full suite (36 scripts) passes. 
Co-Authored-By: Claude Opus 4.7 --- schema/tools.schema.json | 12 +- scripts/test_worker_tool_cost_cap.py | 328 +++++++++++++++++++++++++++ servers/python/crosscheck_server.py | 314 ++++++++++++++++++++++--- 3 files changed, 623 insertions(+), 31 deletions(-) create mode 100644 scripts/test_worker_tool_cost_cap.py diff --git a/schema/tools.schema.json b/schema/tools.schema.json index cb4f151..ed6d6bb 100644 --- a/schema/tools.schema.json +++ b/schema/tools.schema.json @@ -267,7 +267,11 @@ "inject_session_memory": { "type": "boolean", "default": false, "description": "When true (and a session_id is set), prepend the session's non-stale `` block (decisions / facts / open_questions) to the user message. Stale entries from a prior failed audit are excluded." }, "worker_tools": { "type": "array", "items": { "type": "string", "enum": ["fetch", "verify"] }, - "description": "Opt-in: enable bounded mid-turn tool use for each worker. Workers may emit `{\"name\":\"TOOL\",\"args\":{...}}` to request the listed tools; results are wrapped as untrusted-input and re-prompted. Hard hop budget = 2 per worker turn. Only `fetch` and `verify` (read-only / deterministic) are callable — recursive ReAct via LLM-spawning tools is explicitly disallowed." } + "description": "Opt-in: enable bounded mid-turn tool use for each worker. Workers may emit `{\"name\":\"TOOL\",\"args\":{...}}` to request the listed tools; results are wrapped as untrusted-input and re-prompted. Hard hop budget = 2 per worker turn. Only `fetch` and `verify` (read-only / deterministic) are callable — recursive ReAct via LLM-spawning tools is explicitly disallowed." }, + "worker_tool_cost_cap_usd": { "type": "number", "minimum": 0, + "description": "Optional soft USD cap on each worker's cumulative inner-call cost (LLM re-prompts + inner tool calls) for this turn. 0 / omitted = disabled. Defaults to `CFG.worker_tools.cost_cap_usd` when omitted." 
}, + "worker_tool_cost_cap_mode": { "type": "string", "enum": ["warn", "enforce", "off"], "default": "warn", + "description": "Behavior when the cap is crossed. `warn` (default): emit a warning event + a `worker_tools.cost_cap` block on the response and keep running; the calling agent should surface an AskUserQuestion prompt offering enforce / warn-only / ignore for the next call. `enforce`: refuse the next inner tool call when the cap is exceeded (worker gets one final emission round). `off`: no checks." } }, "required": ["question"] }, @@ -400,7 +404,11 @@ "inject_session_memory": { "type": "boolean", "default": false, "description": "When true (and a session_id is set), prepend the session's non-stale `` block to the topic before dispatch. Stale entries (e.g. from a prior failed audit) are excluded." }, "worker_tools": { "type": "array", "items": { "type": "string", "enum": ["fetch", "verify"] }, - "description": "Opt-in: enable bounded mid-turn tool use on the PROPOSER and CRITIC roles (the SYNTHESIZER is purely combinatorial and is intentionally excluded). Workers may emit `{\"name\":\"TOOL\",\"args\":{...}}` BEFORE the final structured emission; tool results are wrapped as untrusted-input and re-prompted. Hard hop budget = 2 per role. Only `fetch` and `verify` are callable — recursive ReAct via LLM-spawning tools is disallowed." } + "description": "Opt-in: enable bounded mid-turn tool use on the PROPOSER and CRITIC roles (the SYNTHESIZER is purely combinatorial and is intentionally excluded). Workers may emit `{\"name\":\"TOOL\",\"args\":{...}}` BEFORE the final structured emission; tool results are wrapped as untrusted-input and re-prompted. Hard hop budget = 2 per role. Only `fetch` and `verify` are callable — recursive ReAct via LLM-spawning tools is disallowed." }, + "worker_tool_cost_cap_usd": { "type": "number", "minimum": 0, + "description": "Optional soft USD cap on each proposer/critic role's cumulative inner-call cost for this turn. 
0 / omitted = disabled. Defaults to `CFG.worker_tools.cost_cap_usd` when omitted." }, + "worker_tool_cost_cap_mode": { "type": "string", "enum": ["warn", "enforce", "off"], "default": "warn", + "description": "Behavior when the cap is crossed. `warn` (default): emit a warning event + a `worker_tools.cost_cap` block on the response and keep running; the calling agent should surface an AskUserQuestion prompt offering enforce / warn-only / ignore for the next call. `enforce`: refuse the next inner tool call when the cap is exceeded (worker gets one final emission round). `off`: no checks." } }, "required": ["topic"] }, diff --git a/scripts/test_worker_tool_cost_cap.py b/scripts/test_worker_tool_cost_cap.py new file mode 100644 index 0000000..a111ab6 --- /dev/null +++ b/scripts/test_worker_tool_cost_cap.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +"""Tests for the per-turn worker tool-use cost cap. + +Three modes: + - warn (default): emit warning event + attach cost_cap block; keep going + - enforce: refuse next inner tool call when cap is exceeded + - off: no checks at all + +The user explicitly asked for warn-as-default + an agent-side +AskUserQuestion follow-up flow; the server emits the structured signal +so the agent can surface it. +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +from pathlib import Path + + +def main() -> int: + here = Path(__file__).resolve().parents[1] + sys.path.insert(0, str(here / "servers" / "python")) + + tmp = Path(tempfile.mkdtemp()) + pricing = tmp / "pricing.json" + # Generous per-token cost so cap thresholds are easy to reason about. + # gpt-test: 30 prompt + 10 completion = 30*0.01 + 10*0.05 = $0.80 per call. 
+ pricing.write_text(json.dumps({ + "openai": {"gpt-test": {"prompt_per_1k": 10.0, "completion_per_1k": 50.0, "cached_per_1k": 5.0}}, + "anthropic": {"claude-test": {"prompt_per_1k": 10.0, "completion_per_1k": 50.0, "cached_per_1k": 5.0}}, + })) + os.environ["CROSSCHECK_PRICING_PATH"] = str(pricing) + os.environ.pop("CROSSCHECK_REJECT_CONFIG_DRIFT", None) + + import crosscheck_server as srv + + srv.CFG = dict(srv.CFG) + srv.CFG["session_db"] = str(tmp / "sessions.db") + srv.CFG["transcript_dir"] = str(tmp / "transcripts") + srv.CFG["cache"] = {"enabled": False} + srv.CFG["node_cache"] = {"enabled": False} + srv.CFG["prompt_adapters"] = {"enabled": False} + srv.CFG["fetch"] = {"url_allowlist": ["https://example.com/"]} + srv.CFG["config_pinning"] = {"reject_drift": False} + srv.CFG.pop("worker_tools", None) + srv.TRANSCRIPT_DIR = Path(srv.CFG["transcript_dir"]) + srv._DB_INIT_DONE = False + srv._FTS5_AVAILABLE = None + srv._PRICING_CACHE = None + srv.PRICING_PATH = pricing + srv._CONFIG_PIN_STARTUP_DONE = True + + srv.ENV = dict(srv.ENV) + srv.ENV["OPENAI_API_KEY"] = "stub"; srv.ENV["OPENAI_MODEL"] = "gpt-test" + srv.ENV["ANTHROPIC_API_KEY"] = "stub"; srv.ENV["ANTHROPIC_MODEL"] = "claude-test" + srv.ALL_PROVIDERS = srv.build_providers() + srv.CFG["providers"] = ["openai", "anthropic"] + srv.CFG["moderator"] = "anthropic" + + openai_responses: list[str] = [] + fetched_urls: list[str] = [] + + def fake_post(url, h, b, **kw): + if "openai.com" in url: + text = openai_responses.pop(0) if openai_responses else "DONE." + return ({"choices": [{"message": {"content": text}}], + "usage": {"prompt_tokens": 30, "completion_tokens": 10}}, 1) + if "anthropic.com" in url: + text = openai_responses.pop(0) if openai_responses else "DONE." 
+ return ({"content": [{"type": "text", "text": text}], + "usage": {"input_tokens": 30, "output_tokens": 10}}, 1) + return ({}, 1) + srv._http_post_resilient = fake_post + + real_tool_fetch = srv.tool_fetch + def fake_tool_fetch(args): + fetched_urls.append(args.get("url", "")) + return {"tool": "fetch", "url": args.get("url"), + "status": "ok", "content_excerpt": "stub"} + srv.tool_fetch = fake_tool_fetch + + p = srv.ALL_PROVIDERS["openai"] + + # ------------------------------------------------------------------ + # 1) Defaults helper + # ------------------------------------------------------------------ + # No CFG, no kwargs -> disabled + cap, mode = srv._worker_tool_cost_cap_defaults(None, None) + assert cap is None and mode == "warn" # mode defaults to warn but cap=None disables + + # Per-call cap with explicit mode + cap, mode = srv._worker_tool_cost_cap_defaults(0.5, "enforce") + assert cap == 0.5 and mode == "enforce" + + # 0 / negative -> disabled + cap, mode = srv._worker_tool_cost_cap_defaults(0, "warn") + assert cap is None + cap, mode = srv._worker_tool_cost_cap_defaults(-1.0, "warn") + assert cap is None + + # Bad mode falls through to warn + cap, mode = srv._worker_tool_cost_cap_defaults(0.5, "blah") + assert mode == "warn" + + # CFG default picked up when kwargs absent + srv.CFG["worker_tools"] = {"cost_cap_usd": 0.25, "cost_cap_mode": "enforce"} + cap, mode = srv._worker_tool_cost_cap_defaults(None, None) + assert cap == 0.25 and mode == "enforce" + + # Per-call kwarg beats CFG default + cap, mode = srv._worker_tool_cost_cap_defaults(1.0, "warn") + assert cap == 1.0 and mode == "warn" + srv.CFG.pop("worker_tools", None) + + # ------------------------------------------------------------------ + # 2) warn mode: cap crossed, loop continues, block attached + # ------------------------------------------------------------------ + # Each call costs $0.80. 
Cap at $1.00 -> first round under cap, second + # round (after first fetch + LLM re-prompt) takes cumulative to $1.60 + # which trips the warning. Worker emits final answer. + openai_responses[:] = [ + '{"name":"fetch","args":{"url":"https://example.com/a"}}', + # After re-prompt + LLM response, cumulative is $1.60 (over $1.00). + # warn mode: keep going. Worker emits one MORE tool call, allowed. + '{"name":"fetch","args":{"url":"https://example.com/b"}}', + # Final answer after second fetch. + "Done with sources gathered.", + ] + fetched_urls.clear() + ans = srv._ask_one_with_tools( + p, + [{"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Investigate."}], + deadline=__import__("time").monotonic() + 30, + max_tokens=2048, purpose="worker", + worker_tools=["fetch"], session_id="warn-mode", + cost_cap_usd=1.0, cost_cap_mode="warn", + ) + assert "Done with sources gathered" in ans["response"], ans + cc = ans["cost_cap"] + assert cc["cap_usd"] == 1.0 + assert cc["mode"] == "warn" + assert cc["exceeded"] is True, cc + assert cc["blocked"] is False, cc + # Both fetches executed (warn never blocks) + assert len(fetched_urls) == 2, fetched_urls + statuses = [c["status"] for c in ans["inner_tool_calls"]] + assert statuses == ["ok", "ok"], statuses + + # ------------------------------------------------------------------ + # 3) enforce mode: next inner call refused after cap crossed + # ------------------------------------------------------------------ + # Cap at $1.00. First fetch + reprompt -> $1.60 -> over cap. + # Worker requests a SECOND fetch but it's refused. Worker gets one + # final round to produce its answer. 
+ openai_responses[:] = [ + '{"name":"fetch","args":{"url":"https://example.com/a"}}', + '{"name":"fetch","args":{"url":"https://example.com/b"}}', + "Final answer with whatever we had.", + ] + fetched_urls.clear() + ans = srv._ask_one_with_tools( + p, + [{"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Investigate."}], + deadline=__import__("time").monotonic() + 30, + max_tokens=2048, purpose="worker", + worker_tools=["fetch"], session_id="enforce-mode", + cost_cap_usd=1.0, cost_cap_mode="enforce", + ) + cc = ans["cost_cap"] + assert cc["mode"] == "enforce" + assert cc["exceeded"] is True, cc + assert cc["blocked"] is True, cc + # Only ONE fetch ran; the second got cost_cap_exceeded + assert len(fetched_urls) == 1, fetched_urls + statuses = [c["status"] for c in ans["inner_tool_calls"]] + assert statuses == ["ok", "cost_cap_exceeded"], statuses + assert "Final answer" in ans["response"] + + # ------------------------------------------------------------------ + # 4) off mode: no checks at all (no cost_cap block, no events) + # ------------------------------------------------------------------ + openai_responses[:] = [ + '{"name":"fetch","args":{"url":"https://example.com/a"}}', + "Done.", + ] + fetched_urls.clear() + ans = srv._ask_one_with_tools( + p, + [{"role": "user", "content": "Investigate."}], + deadline=__import__("time").monotonic() + 30, + max_tokens=2048, purpose="worker", + worker_tools=["fetch"], session_id="off-mode", + cost_cap_usd=0.01, # would absolutely trip in warn/enforce + cost_cap_mode="off", + ) + assert "cost_cap" not in ans, ans + + # ------------------------------------------------------------------ + # 5) Cap not set -> no cost_cap block (legacy worker_tools behavior) + # ------------------------------------------------------------------ + openai_responses[:] = [ + '{"name":"fetch","args":{"url":"https://example.com/a"}}', + "Done.", + ] + fetched_urls.clear() + ans = srv._ask_one_with_tools( + p, + 
[{"role": "user", "content": "Investigate."}], + deadline=__import__("time").monotonic() + 30, + max_tokens=2048, purpose="worker", + worker_tools=["fetch"], session_id="no-cap", + ) + assert "cost_cap" not in ans + + # ------------------------------------------------------------------ + # 6) Cap NOT exceeded -> block still attached but exceeded=false + # ------------------------------------------------------------------ + openai_responses[:] = [ + "Immediate answer, no tool calls.", + ] + ans = srv._ask_one_with_tools( + p, + [{"role": "user", "content": "Hi."}], + deadline=__import__("time").monotonic() + 30, + max_tokens=2048, purpose="worker", + worker_tools=["fetch"], session_id="under-cap", + cost_cap_usd=5.0, cost_cap_mode="warn", + ) + cc = ans["cost_cap"] + assert cc["cap_usd"] == 5.0 + assert cc["exceeded"] is False, cc + assert cc["blocked"] is False + + # ------------------------------------------------------------------ + # 7) Structured variant: cost_cap honored inside _request_structured_with_tools + # ------------------------------------------------------------------ + role_schema = srv._role_turn_schema() + valid_proposer = json.dumps({ + "role": "proposer", "summary": "draft", "confidence": 0.8, + "ballot": "agree", + }) + openai_responses[:] = [ + '{"name":"fetch","args":{"url":"https://example.com/a"}}', + '{"name":"fetch","args":{"url":"https://example.com/b"}}', + valid_proposer, + ] + fetched_urls.clear() + obj, ans, errs = srv._request_structured( + p, + [{"role": "system", "content": "PROPOSER."}, + {"role": "user", "content": "Draft."}], + role_schema, + max_tokens=2048, deadline=__import__("time").monotonic() + 30, + max_retries=1, purpose="worker", + worker_tools=["fetch"], session_id="rst-enforce", + cost_cap_usd=1.0, cost_cap_mode="enforce", + ) + assert errs == [], errs + assert obj["role"] == "proposer" + cc = ans["cost_cap"] + assert cc["blocked"] is True, cc + # Only one fetch executed; second refused with cost_cap_exceeded + 
assert len(fetched_urls) == 1, fetched_urls + statuses = [c["status"] for c in ans["inner_tool_calls"]] + assert "cost_cap_exceeded" in statuses, statuses + + # ------------------------------------------------------------------ + # 8) End-to-end coordinate surfaces cost_cap on result + operator_prompt + # ------------------------------------------------------------------ + valid_critic = json.dumps({ + "role": "critic", "summary": "looks ok", "confidence": 0.7, + "ballot": "agree", + }) + valid_synth = json.dumps({ + "consensus": "use opaque tokens", "weighted_confidence": 0.85, + "key_claims": [], + }) + # Proposer (anthropic) does 2 fetches + final emission -> exceeds cap + # Critics (openai) each do 0 fetches + final emission -> under cap individually + # Synth (anthropic) is not given tools -> no cost_cap block + openai_responses[:] = [ + # Anthropic proposer: 2 tool calls + final + '{"name":"fetch","args":{"url":"https://example.com/p1"}}', + '{"name":"fetch","args":{"url":"https://example.com/p2"}}', + valid_proposer, + # OpenAI critic: final immediately + valid_critic, + # Synth (anthropic): final immediately + valid_synth, + ] + fetched_urls.clear() + res = srv.tool_coordinate({ + "topic": "tokens", + "providers": ["openai", "anthropic"], + "proposer": "anthropic", + "critics": ["openai"], + "synthesizer": "anthropic", + "worker_tools": ["fetch"], + "worker_tool_cost_cap_usd": 1.0, + "worker_tool_cost_cap_mode": "warn", + "session_id": "coord-warn", + }) + wt = res["worker_tools"] + assert "cost_cap" in wt, wt + cc = wt["cost_cap"] + assert cc["cap_usd"] == 1.0 and cc["mode"] == "warn" + assert cc["exceeded_any"] is True + assert cc["blocked_any"] is False + # operator_prompt only appears when exceeded but NOT blocked (warn mode) + assert "operator_prompt" in cc, cc + # Synth not in per_role list + roles_in_caps = [r["role"] for r in cc["per_role"]] + assert "synthesizer" not in roles_in_caps and "synth" not in roles_in_caps, roles_in_caps + # Proposer is 
in there + assert "proposer" in roles_in_caps, roles_in_caps + + srv.tool_fetch = real_tool_fetch + print("OK: test_worker_tool_cost_cap") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/servers/python/crosscheck_server.py b/servers/python/crosscheck_server.py index 6b57586..b7ed356 100755 --- a/servers/python/crosscheck_server.py +++ b/servers/python/crosscheck_server.py @@ -2708,6 +2708,57 @@ def _ask_one(p: Provider, messages: list[dict], deadline: float, max_tokens: int _TOOL_CALL_RE = re.compile(r"(.*?)", flags=re.DOTALL | re.IGNORECASE) _WORKER_TOOLS_MAX_RESULT_CHARS = 4000 # per-result truncation before re-prompt +_WORKER_TOOL_COST_CAP_MODES = ("warn", "enforce", "off") + + +def _worker_tool_cost_cap_defaults(caller_cap_usd: Any, + caller_mode: Any + ) -> tuple[float | None, str]: + """Resolve per-call kwargs against CFG defaults. Returns + (cap_usd_or_None, mode_in_{warn,enforce,off}). + + cap_usd_or_None == None means the cap is disabled entirely — no + tracking, no warnings, no enforcement. 
Mode is only meaningful + when a cap is set.""" + cfg = (CFG.get("worker_tools") or {}) if isinstance(CFG, dict) else {} + cap = caller_cap_usd if caller_cap_usd is not None else cfg.get("cost_cap_usd") + try: + cap = float(cap) if cap is not None else None + except (TypeError, ValueError): + cap = None + if cap is not None and cap <= 0: + cap = None + mode = caller_mode if isinstance(caller_mode, str) and caller_mode \ + else cfg.get("cost_cap_mode") or "warn" + if mode not in _WORKER_TOOL_COST_CAP_MODES: + mode = "warn" + return cap, mode + + +def _worker_tool_cost_observed(aggregated: dict) -> float: + """Pull the cumulative cost out of the merged answer's usage block.""" + if not isinstance(aggregated, dict): + return 0.0 + usage = aggregated.get("usage") or {} + try: + return float(usage.get("cost_usd", 0.0) or 0.0) + except (TypeError, ValueError): + return 0.0 + + +def _worker_tool_cost_cap_refusal(observed: float, cap: float) -> str: + """`enforce`-mode refusal payload wrapped as a tool_result so the + worker sees a consistent envelope shape.""" + payload = { + "refused": True, + "tool": "", + "reason": (f"per-turn cost cap exceeded " + f"(observed=${observed:.4f}, cap=${cap:.4f})"), + "operator_hint": ("The worker has used up its inner-tool cost " + "budget for this turn. Emit your final answer " + "now using whatever information you already have."), + } + return _wrap_tool_result("", json.dumps(payload)) def _worker_tools_system_hint(worker_tools: list[str]) -> str: @@ -2866,14 +2917,19 @@ def _merge_answer_usage(base: dict, extra: dict) -> dict: def _ask_one_with_tools(p: Provider, messages: list[dict], deadline: float, max_tokens: int, purpose: str, *, worker_tools: list[str], - session_id: str | None) -> dict: + session_id: str | None, + cost_cap_usd: float | None = None, + cost_cap_mode: str | None = None) -> dict: """`_ask_one` wrapped in a bounded tool-call loop. 
Returns the same answer shape, plus an `inner_tool_calls` field listing each inner - call's name + status.""" + call's name + status. When a `cost_cap_usd` is set, attaches a + `cost_cap` block summarising the per-turn cost-cap state.""" allowed = [t for t in (worker_tools or []) if t in _WORKER_TOOL_ALLOWLIST] if not allowed: return _ask_one(p, messages, deadline, max_tokens, purpose=purpose) + cap_usd, cap_mode = _worker_tool_cost_cap_defaults(cost_cap_usd, cost_cap_mode) + # Inject the tool-use system hint into the FIRST system message (or # add one) so the worker knows the envelope syntax. msgs = [dict(m) for m in messages] @@ -2888,6 +2944,33 @@ def _ask_one_with_tools(p: Provider, messages: list[dict], deadline: float, aggregated: dict = {} hops = 0 last_answer: dict = {} + cost_cap_warning_emitted = False + cost_cap_blocked = False + + def _check_cost_cap() -> bool: + """Returns True if the cap is set, mode is enforce, AND observed > cap. + Side-effect: emits the warn event at most once for warn-mode.""" + nonlocal cost_cap_warning_emitted + if cap_usd is None or cap_mode == "off": + return False + observed = _worker_tool_cost_observed(aggregated) + if observed <= cap_usd: + return False + if not cost_cap_warning_emitted: + cost_cap_warning_emitted = True + _emit_event("worker_tool_cost_warning", + provider=p.name, model=p.model, purpose=purpose, + observed_usd=round(observed, 6), + cap_usd=round(cap_usd, 6), + mode=cap_mode, + session_id=session_id) + _emit_progress( + f"{p.name}: worker_tool cost cap exceeded " + f"(observed=${observed:.4f}, cap=${cap_usd:.4f}, mode={cap_mode})", + provider=p.name, model=p.model, purpose=purpose, + observed_usd=observed, cap_usd=cap_usd, mode=cap_mode, + ) + return cap_mode == "enforce" while True: ans = _ask_one(p, msgs, deadline, max_tokens, purpose=purpose) @@ -2918,6 +3001,27 @@ def _ask_one_with_tools(p: Provider, messages: list[dict], deadline: float, hops += 1 continue + # Cost-cap check BEFORE dispatching the call 
so an `enforce`-mode + # refusal happens BEFORE the next inner call (and before the worker + # spends more LLM tokens replying to it). + if _check_cost_cap(): + cost_cap_blocked = True + inner_calls.append({"hop": hops + 1, + "name": call.get("name"), + "status": "cost_cap_exceeded"}) + msgs = list(msgs) + [ + {"role": "assistant", "content": ans["response"]}, + {"role": "user", + "content": _worker_tool_cost_cap_refusal( + _worker_tool_cost_observed(aggregated), cap_usd)}, + ] + # One final emission round so the worker sees the refusal and + # can produce an answer with what it already has. + ans2 = _ask_one(p, msgs, deadline, max_tokens, purpose=purpose) + aggregated = _merge_answer_usage(aggregated, ans2) + last_answer = ans2 + break + # Hop budget check BEFORE executing the call so a 3rd request gets # a refusal it can incorporate (not an executed call). if hops >= _WORKER_TOOL_HOP_BUDGET: @@ -2958,9 +3062,23 @@ def _ask_one_with_tools(p: Provider, messages: list[dict], deadline: float, {"role": "user", "content": result_block}, ] hops += 1 + # Re-check after the call settles so warn-mode still emits when + # the LLM re-prompt pushed us over the cap. + _check_cost_cap() if inner_calls: aggregated["inner_tool_calls"] = inner_calls + # Mode "off" suppresses the cost_cap block entirely — the operator + # asked to ignore the cap, so don't surface it in the response. 
+ if cap_usd is not None and cap_mode != "off": + observed = _worker_tool_cost_observed(aggregated) + aggregated["cost_cap"] = { + "cap_usd": round(cap_usd, 6), + "observed_usd": round(observed, 6), + "mode": cap_mode, + "exceeded": observed > cap_usd, + "blocked": cost_cap_blocked, + } return aggregated @@ -2991,14 +3109,18 @@ def _budget_summary(call_started: float, deadline: float, answers: list[dict], def _ask_many_parallel(providers: list[Provider], messages: list[dict], deadline: float, max_tokens: int, purpose: str = "worker", *, worker_tools: list[str] | None = None, - session_id: str | None = None) -> list[dict]: + session_id: str | None = None, + cost_cap_usd: float | None = None, + cost_cap_mode: str | None = None) -> list[dict]: # When worker_tools is provided + non-empty, each worker runs in the # bounded tool-call loop; otherwise the standard single-shot dispatch. def _dispatch_one(provider: Provider) -> dict: if worker_tools: return _ask_one_with_tools(provider, messages, deadline, max_tokens, purpose, worker_tools=worker_tools, - session_id=session_id) + session_id=session_id, + cost_cap_usd=cost_cap_usd, + cost_cap_mode=cost_cap_mode) return _ask_one(provider, messages, deadline, max_tokens, purpose) if len(providers) <= 1: @@ -3404,6 +3526,8 @@ def tool_confer(args: dict) -> dict: rejected_worker_tools = [t for t in requested_worker_tools if not (isinstance(t, str) and t in _WORKER_TOOL_ALLOWLIST)] inner_session_id = session.get("session_id") if session else args.get("session_id") + cost_cap_usd_arg = args.get("worker_tool_cost_cap_usd") + cost_cap_mode_arg = args.get("worker_tool_cost_cap_mode") if early_stop and len(selected) >= 3: # Phase 1: dispatch the first 2 panelists; check agreement; skip the @@ -3411,7 +3535,9 @@ def tool_confer(args: dict) -> dict: phase1 = _ask_many_parallel(selected[:2], messages, deadline, per_call, purpose="confer", worker_tools=accepted_worker_tools, - session_id=inner_session_id) + session_id=inner_session_id, + 
cost_cap_usd=cost_cap_usd_arg, + cost_cap_mode=cost_cap_mode_arg) phase1_clean = [a for a in phase1 if isinstance(a, dict) and not a.get("error")] # If a breaker would trip once phase-1's cost is rolled in, skip the @@ -3447,13 +3573,17 @@ def tool_confer(args: dict) -> dict: phase2 = _ask_many_parallel(selected[2:], messages, deadline, per_call, purpose="confer", worker_tools=accepted_worker_tools, - session_id=inner_session_id) + session_id=inner_session_id, + cost_cap_usd=cost_cap_usd_arg, + cost_cap_mode=cost_cap_mode_arg) answers = phase1 + phase2 else: answers = _ask_many_parallel(selected, messages, deadline, per_call, purpose="confer", worker_tools=accepted_worker_tools, - session_id=inner_session_id) + session_id=inner_session_id, + cost_cap_usd=cost_cap_usd_arg, + cost_cap_mode=cost_cap_mode_arg) # Scan for canary leaks BEFORE downstream derived structures consume # the answers. Any provider that echoed the nonce had indirect injection @@ -3497,9 +3627,40 @@ def tool_confer(args: dict) -> dict: if canary_leaks: result["canary_leaks"] = canary_leaks if requested_worker_tools: - result["worker_tools"] = {"accepted": accepted_worker_tools, - "rejected": rejected_worker_tools, - "hop_budget": _WORKER_TOOL_HOP_BUDGET} + wt_meta: dict = {"accepted": accepted_worker_tools, + "rejected": rejected_worker_tools, + "hop_budget": _WORKER_TOOL_HOP_BUDGET} + # Aggregate per-worker cost-cap state — anyone exceeded triggers + # the agent-facing warning even if only some panelists ran over. 
+ cost_caps = [a.get("cost_cap") for a in answers + if isinstance(a, dict) and isinstance(a.get("cost_cap"), dict)] + if cost_caps: + any_exceeded = any(c.get("exceeded") for c in cost_caps) + any_blocked = any(c.get("blocked") for c in cost_caps) + wt_meta["cost_cap"] = { + "cap_usd": cost_caps[0]["cap_usd"], + "mode": cost_caps[0]["mode"], + "exceeded_any": bool(any_exceeded), + "blocked_any": bool(any_blocked), + "per_provider": [ + {"provider": a.get("provider"), + **(a.get("cost_cap") or {})} + for a in answers + if isinstance(a, dict) and isinstance(a.get("cost_cap"), dict) + ], + } + if any_exceeded and not any_blocked: + # Operator-facing prompt: this is the agent's signal to + # ask the user how to proceed (enforce next time, allow + # override, ignore). + wt_meta["cost_cap"]["operator_prompt"] = ( + "Worker tool-use cost exceeded the soft cap on at " + "least one panelist. Decide before the next call: " + "enforce (refuse further inner calls), warn-only " + "(continue, override allowed), or ignore (disable " + "the cap)." 
+ ) + result["worker_tools"] = wt_meta if early_stop: result["early_stopped"] = early_stopped result["skipped_providers"] = skipped_providers @@ -4101,6 +4262,8 @@ def tool_coordinate(args: dict) -> dict: rejected_worker_tools = [t for t in requested_worker_tools if not (isinstance(t, str) and t in _WORKER_TOOL_ALLOWLIST)] inner_session_id = session.get("session_id") if session else args.get("session_id") + cost_cap_usd_arg = args.get("worker_tool_cost_cap_usd") + cost_cap_mode_arg = args.get("worker_tool_cost_cap_mode") # ---- Step 1: Proposer ---------------------------------------------------- prop_system = ( @@ -4117,6 +4280,8 @@ def tool_coordinate(args: dict) -> dict: purpose="worker", worker_tools=accepted_worker_tools, session_id=inner_session_id, + cost_cap_usd=cost_cap_usd_arg, + cost_cap_mode=cost_cap_mode_arg, ) proposal_render = _format_role_turn("proposer", proposal_obj, fallback_text=proposal_ans.get("response", "") if proposal_ans else "") @@ -4137,6 +4302,8 @@ def _critique(p: Provider) -> tuple[Provider, dict | None, dict, list[str]]: purpose="worker", worker_tools=accepted_worker_tools, session_id=inner_session_id, + cost_cap_usd=cost_cap_usd_arg, + cost_cap_mode=cost_cap_mode_arg, ) return p, obj, ans, errs @@ -4295,12 +4462,39 @@ def _critique(p: Provider) -> tuple[Provider, dict | None, dict, list[str]]: if canary_leaks: result["canary_leaks"] = canary_leaks if requested_worker_tools: - result["worker_tools"] = { + wt_meta: dict = { "accepted": accepted_worker_tools, "rejected": rejected_worker_tools, "hop_budget": _WORKER_TOOL_HOP_BUDGET, "applies_to": ["proposer", "critic"], # synth is excluded } + # Pull cost_cap state off the proposer + critic answers (synth is + # excluded from worker_tools so it has no cost_cap block). 
+ role_caps = [] + if isinstance(proposal_ans, dict) and isinstance(proposal_ans.get("cost_cap"), dict): + role_caps.append(("proposer", proposal_ans["cost_cap"])) + for ans in critique_answers: + if isinstance(ans, dict) and isinstance(ans.get("cost_cap"), dict): + role_caps.append(("critic", ans["cost_cap"])) + if role_caps: + any_exceeded = any(c.get("exceeded") for _, c in role_caps) + any_blocked = any(c.get("blocked") for _, c in role_caps) + wt_meta["cost_cap"] = { + "cap_usd": role_caps[0][1]["cap_usd"], + "mode": role_caps[0][1]["mode"], + "exceeded_any": bool(any_exceeded), + "blocked_any": bool(any_blocked), + "per_role": [{"role": role, **cap} for role, cap in role_caps], + } + if any_exceeded and not any_blocked: + wt_meta["cost_cap"]["operator_prompt"] = ( + "Worker tool-use cost exceeded the soft cap on at " + "least one role. Decide before the next call: " + "enforce (refuse further inner calls), warn-only " + "(continue, override allowed), or ignore (disable " + "the cap)." + ) + result["worker_tools"] = wt_meta _attach_usage_block(result, all_answers, session_id=session.get("session_id") if session else None, tool_name="coordinate") @@ -9444,6 +9638,8 @@ def _request_structured(p: "Provider", base_messages: list[dict], schema: dict, *, worker_tools: list[str] | None = None, session_id: str | None = None, + cost_cap_usd: float | None = None, + cost_cap_mode: str | None = None, ) -> tuple[Any | None, dict, list[str]]: """Ask provider for JSON matching `schema`. Validate; retry once on failure with the errors fed back in the prompt. Returns (parsed_or_None, raw_answer, errors). 
@@ -9460,6 +9656,7 @@ def _request_structured(p: "Provider", base_messages: list[dict], schema: dict, p, base_messages, schema, max_tokens=max_tokens, deadline=deadline, max_retries=max_retries, purpose=purpose, worker_tools=allowed_tools, session_id=session_id, + cost_cap_usd=cost_cap_usd, cost_cap_mode=cost_cap_mode, ) sys_idx = next((i for i, m in enumerate(base_messages) if m.get("role") == "system"), None) schema_text = json.dumps(schema, separators=(",", ":")) @@ -9503,11 +9700,14 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], max_retries: int, purpose: str, worker_tools: list[str], session_id: str | None, + cost_cap_usd: float | None = None, + cost_cap_mode: str | None = None, ) -> tuple[Any | None, dict, list[str]]: """Tool-call-aware variant of `_request_structured`. Interleaves the `_ask_one_with_tools` hop loop with schema-validated emission: tool calls (counted against the hop budget) come first; the final emission - is parsed + validated; one retry on validation failure.""" + is parsed + validated; one retry on validation failure. 
The per-turn + cost cap is honored identically to `_ask_one_with_tools`.""" sys_idx = next((i for i, m in enumerate(base_messages) if m.get("role") == "system"), None) schema_text = json.dumps(schema, separators=(",", ":")) instr = ( @@ -9528,12 +9728,54 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], else: msgs.insert(0, {"role": "system", "content": (instr + hint).strip()}) + cap_usd, cap_mode = _worker_tool_cost_cap_defaults(cost_cap_usd, cost_cap_mode) + aggregated: dict = {} inner_calls: list[dict] = [] hops = 0 attempt = 0 last_answer: dict = {} last_errs: list[str] = [] + cost_cap_warning_emitted = False + cost_cap_blocked = False + + def _check_cost_cap() -> bool: + nonlocal cost_cap_warning_emitted + if cap_usd is None or cap_mode == "off": + return False + observed = _worker_tool_cost_observed(aggregated) + if observed <= cap_usd: + return False + if not cost_cap_warning_emitted: + cost_cap_warning_emitted = True + _emit_event("worker_tool_cost_warning", + provider=p.name, model=p.model, purpose=purpose, + observed_usd=round(observed, 6), + cap_usd=round(cap_usd, 6), + mode=cap_mode, + session_id=session_id) + _emit_progress( + f"{p.name}: worker_tool cost cap exceeded " + f"(observed=${observed:.4f}, cap=${cap_usd:.4f}, mode={cap_mode})", + provider=p.name, model=p.model, purpose=purpose, + observed_usd=observed, cap_usd=cap_usd, mode=cap_mode, + ) + return cap_mode == "enforce" + + def _finalize(obj: Any | None, errs: list[str]) -> tuple[Any | None, dict, list[str]]: + if inner_calls: + aggregated["inner_tool_calls"] = inner_calls + # Mode "off" suppresses the cost_cap block entirely. 
+ if cap_usd is not None and cap_mode != "off": + observed = _worker_tool_cost_observed(aggregated) + aggregated["cost_cap"] = { + "cap_usd": round(cap_usd, 6), + "observed_usd": round(observed, 6), + "mode": cap_mode, + "exceeded": observed > cap_usd, + "blocked": cost_cap_blocked, + } + return obj, aggregated, errs while True: ans = _ask_one(p, msgs, deadline, max_tokens, purpose=purpose) @@ -9541,11 +9783,9 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], aggregated = _merge_answer_usage(aggregated, ans) if aggregated else dict(ans) if "error" in ans or not isinstance(ans.get("response"), str): - if inner_calls: - aggregated["inner_tool_calls"] = inner_calls - return None, aggregated, [ + return _finalize(None, [ f"provider error: {ans.get('error_kind', 'other')}: {ans.get('error', '')}" - ] + ]) text = ans["response"] call, parse_err = _extract_tool_call(text) @@ -9556,8 +9796,7 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], inner_calls.append({"hop": hops + 1, "name": None, "status": "parse_error", "error": parse_err}) if hops >= _WORKER_TOOL_HOP_BUDGET: - aggregated["inner_tool_calls"] = inner_calls - return None, aggregated, [parse_err or "malformed tool_call"] + return _finalize(None, [parse_err or "malformed tool_call"]) msgs = list(msgs) + [ {"role": "assistant", "content": text}, {"role": "user", @@ -9569,6 +9808,28 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], hops += 1 continue + # Cost-cap check BEFORE hop budget check: an enforce-mode cap + # should refuse the call regardless of whether the hop budget + # still has room. 
+ if _check_cost_cap(): + cost_cap_blocked = True + inner_calls.append({"hop": hops + 1, + "name": call.get("name"), + "status": "cost_cap_exceeded"}) + msgs = list(msgs) + [ + {"role": "assistant", "content": text}, + {"role": "user", + "content": _worker_tool_cost_cap_refusal( + _worker_tool_cost_observed(aggregated), cap_usd)}, + ] + ans2 = _ask_one(p, msgs, deadline, max_tokens, purpose=purpose) + aggregated = _merge_answer_usage(aggregated, ans2) + last_answer = ans2 + text = ans2.get("response", "") if isinstance(ans2, dict) else "" + obj = _extract_json(text) if text else None + errs = _validate(obj, schema) if obj else ["could not parse JSON from response"] + return _finalize(obj if not errs else None, errs) + if hops >= _WORKER_TOOL_HOP_BUDGET: inner_calls.append({"hop": hops + 1, "name": call.get("name"), @@ -9584,14 +9845,10 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], ans2 = _ask_one(p, msgs, deadline, max_tokens, purpose=purpose) aggregated = _merge_answer_usage(aggregated, ans2) last_answer = ans2 - # Fall through to schema validation on the next iteration's - # response by setting `text` and continuing the parse path. text = ans2.get("response", "") if isinstance(ans2, dict) else "" - # Don't retry on schema fail; we're already at the limit. obj = _extract_json(text) if text else None errs = _validate(obj, schema) if obj else ["could not parse JSON from response"] - aggregated["inner_tool_calls"] = inner_calls - return (obj if not errs else None), aggregated, errs + return _finalize(obj if not errs else None, errs) tool_name = call.get("name") result_block = _worker_tools_dispatch(call, session_id=session_id) @@ -9612,6 +9869,9 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], {"role": "user", "content": result_block}, ] hops += 1 + # Post-call cap check so warn-mode notices when an LLM + # re-prompt itself pushed us over the cap. 
+ _check_cost_cap() continue # No tool_call in the response: this is the worker's final emission. @@ -9622,15 +9882,11 @@ def _request_structured_with_tools(p: "Provider", base_messages: list[dict], else: errs = _validate(obj, schema) if not errs: - if inner_calls: - aggregated["inner_tool_calls"] = inner_calls - return obj, aggregated, [] + return _finalize(obj, []) last_errs = errs if attempt >= max_retries: - if inner_calls: - aggregated["inner_tool_calls"] = inner_calls - return None, aggregated, last_errs + return _finalize(None, last_errs) # One re-prompt with validation feedback. msgs = list(msgs) + [ {"role": "assistant", "content": text},